mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-07-03 05:38:26 +00:00
Compare commits
8 Commits
codex/crit
...
codex/cove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7141bda563 | ||
|
|
74d4366994 | ||
|
|
c85c0585ab | ||
|
|
804630fa02 | ||
|
|
7417864353 | ||
|
|
1f26dc2b57 | ||
|
|
2ec2a0e091 | ||
|
|
9cf5c9a93f |
8
.github/workflows/ci-tests-unit.yaml
vendored
8
.github/workflows/ci-tests-unit.yaml
vendored
@@ -58,11 +58,3 @@ jobs:
|
||||
|
||||
- name: Enforce critical coverage gate
|
||||
run: pnpm test:coverage:critical
|
||||
|
||||
- name: Upload critical coverage summary
|
||||
if: always() && !cancelled() && hashFiles('coverage/coverage-summary.json') != ''
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: critical-coverage-summary
|
||||
path: coverage/coverage-summary.json
|
||||
retention-days: 1
|
||||
|
||||
26
.github/workflows/pr-report.yaml
vendored
26
.github/workflows/pr-report.yaml
vendored
@@ -2,11 +2,7 @@ name: 'PR: Unified Report'
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows:
|
||||
- 'CI: Size Data'
|
||||
- 'CI: Performance Report'
|
||||
- 'CI: E2E Coverage'
|
||||
- 'CI: Tests Unit'
|
||||
workflows: ['CI: Size Data', 'CI: Performance Report', 'CI: E2E Coverage']
|
||||
types:
|
||||
- completed
|
||||
branches-ignore:
|
||||
@@ -94,25 +90,6 @@ jobs:
|
||||
path: temp/coverage
|
||||
if_no_artifact_found: warn
|
||||
|
||||
- name: Find critical coverage workflow run
|
||||
if: steps.pr-meta.outputs.skip != 'true'
|
||||
id: find-critical-coverage
|
||||
uses: ./.github/actions/find-workflow-run
|
||||
with:
|
||||
workflow-id: ci-tests-unit.yaml
|
||||
head-sha: ${{ steps.pr-meta.outputs.head-sha }}
|
||||
not-found-status: skip
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Download critical coverage summary
|
||||
if: steps.pr-meta.outputs.skip != 'true' && (steps.find-critical-coverage.outputs.status == 'ready' || steps.find-critical-coverage.outputs.status == 'failed')
|
||||
uses: dawidd6/action-download-artifact@0bd50d53a6d7fb5cb921e607957e9cc12b4ce392 # v12
|
||||
with:
|
||||
name: critical-coverage-summary
|
||||
run_id: ${{ steps.find-critical-coverage.outputs.run-id }}
|
||||
path: temp/critical-coverage
|
||||
if_no_artifact_found: warn
|
||||
|
||||
- name: Download perf metrics (current)
|
||||
if: steps.pr-meta.outputs.skip != 'true' && steps.find-perf.outputs.status == 'ready'
|
||||
uses: dawidd6/action-download-artifact@0bd50d53a6d7fb5cb921e607957e9cc12b4ce392 # v12
|
||||
@@ -152,7 +129,6 @@ jobs:
|
||||
--size-status=${{ steps.find-size.outputs.status }}
|
||||
--perf-status=${{ steps.find-perf.outputs.status }}
|
||||
--coverage-status=${{ steps.find-coverage.outputs.status }}
|
||||
--critical-coverage-status=${{ steps.find-critical-coverage.outputs.status }}
|
||||
> pr-report.md
|
||||
|
||||
- name: Remove legacy separate comments
|
||||
|
||||
@@ -15,7 +15,7 @@ const { categories } = defineProps<{
|
||||
|
||||
const activeSection = ref(categories[0]?.value ?? '')
|
||||
|
||||
const HEADER_OFFSET = -144
|
||||
const HEADER_OFFSET_PX = -144
|
||||
const BOTTOM_THRESHOLD_PX = 4
|
||||
const SCROLL_SAFETY_MS = 1500
|
||||
|
||||
@@ -52,7 +52,7 @@ function scrollToSection(id: string) {
|
||||
const el = document.getElementById(id)
|
||||
if (el) {
|
||||
scrollTo(el, {
|
||||
offset: HEADER_OFFSET,
|
||||
offset: HEADER_OFFSET_PX,
|
||||
duration: 0.8,
|
||||
immediate: prefersReducedMotion(),
|
||||
onComplete: clearScrollLock
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<li
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow before:content-['']"
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow"
|
||||
>
|
||||
<slot />
|
||||
</li>
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
formatCoverageMetric,
|
||||
renderCriticalCoverageReport
|
||||
} from './unified-report'
|
||||
|
||||
describe('formatCoverageMetric', () => {
|
||||
it('formats covered counts and percent', () => {
|
||||
expect(
|
||||
formatCoverageMetric({
|
||||
covered: 8,
|
||||
total: 10,
|
||||
pct: 80
|
||||
})
|
||||
).toBe('8/10 | 80.00%')
|
||||
})
|
||||
|
||||
it('falls back when a metric is missing or invalid', () => {
|
||||
expect(formatCoverageMetric()).toBe('N/A | N/A')
|
||||
expect(
|
||||
formatCoverageMetric({
|
||||
covered: Number.NaN,
|
||||
total: 10,
|
||||
pct: 80
|
||||
})
|
||||
).toBe('N/A | N/A')
|
||||
})
|
||||
})
|
||||
|
||||
describe('renderCriticalCoverageReport', () => {
|
||||
it('renders critical coverage rows from a summary', () => {
|
||||
expect(
|
||||
renderCriticalCoverageReport({
|
||||
total: {
|
||||
statements: { covered: 8, total: 10, pct: 80 },
|
||||
functions: { covered: 3, total: 4, pct: 75 },
|
||||
lines: { covered: 9, total: 10, pct: 90 }
|
||||
}
|
||||
})
|
||||
).toBe(
|
||||
[
|
||||
'## Critical Unit Coverage',
|
||||
'',
|
||||
'| Metric | Covered | Coverage |',
|
||||
'|---|---:|---:|',
|
||||
'| Statements | 8/10 | 80.00% |',
|
||||
'| Branches | N/A | N/A |',
|
||||
'| Functions | 3/4 | 75.00% |',
|
||||
'| Lines | 9/10 | 90.00% |'
|
||||
].join('\n')
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,5 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import { existsSync, readFileSync } from 'node:fs'
|
||||
import { pathToFileURL } from 'node:url'
|
||||
import { existsSync } from 'node:fs'
|
||||
|
||||
const args: string[] = process.argv.slice(2)
|
||||
|
||||
@@ -13,58 +12,6 @@ function getArg(name: string): string | undefined {
|
||||
const sizeStatus = getArg('size-status') ?? 'pending'
|
||||
const perfStatus = getArg('perf-status') ?? 'pending'
|
||||
const coverageStatus = getArg('coverage-status') ?? 'skip'
|
||||
const criticalCoverageStatus = getArg('critical-coverage-status') ?? 'skip'
|
||||
|
||||
export type CoverageMetricName =
|
||||
| 'statements'
|
||||
| 'branches'
|
||||
| 'functions'
|
||||
| 'lines'
|
||||
export type CoverageMetric = {
|
||||
total: number
|
||||
covered: number
|
||||
pct: number
|
||||
}
|
||||
export type CoverageSummary = {
|
||||
total?: Partial<Record<CoverageMetricName, CoverageMetric>>
|
||||
}
|
||||
|
||||
export function formatCoverageMetric(metric?: CoverageMetric): string {
|
||||
if (
|
||||
!metric ||
|
||||
!Number.isFinite(metric.covered) ||
|
||||
!Number.isFinite(metric.total) ||
|
||||
!Number.isFinite(metric.pct)
|
||||
) {
|
||||
return 'N/A | N/A'
|
||||
}
|
||||
|
||||
return `${metric.covered}/${metric.total} | ${metric.pct.toFixed(2)}%`
|
||||
}
|
||||
|
||||
export function renderCriticalCoverageReport(
|
||||
summary: CoverageSummary = JSON.parse(
|
||||
readFileSync('temp/critical-coverage/coverage-summary.json', 'utf-8')
|
||||
) as CoverageSummary
|
||||
): string {
|
||||
const rows: Array<[string, CoverageMetricName]> = [
|
||||
['Statements', 'statements'],
|
||||
['Branches', 'branches'],
|
||||
['Functions', 'functions'],
|
||||
['Lines', 'lines']
|
||||
]
|
||||
|
||||
return [
|
||||
'## Critical Unit Coverage',
|
||||
'',
|
||||
'| Metric | Covered | Coverage |',
|
||||
'|---|---:|---:|',
|
||||
...rows.map(([label, key]) => {
|
||||
const metric = summary.total?.[key]
|
||||
return `| ${label} | ${formatCoverageMetric(metric)} |`
|
||||
})
|
||||
].join('\n')
|
||||
}
|
||||
|
||||
const lines: string[] = []
|
||||
|
||||
@@ -150,41 +97,4 @@ if (coverageStatus === 'ready' && existsSync('temp/coverage/coverage.lcov')) {
|
||||
lines.push('> ⚠️ Coverage collection failed. Check the CI workflow logs.')
|
||||
}
|
||||
|
||||
if (
|
||||
(criticalCoverageStatus === 'ready' || criticalCoverageStatus === 'failed') &&
|
||||
existsSync('temp/critical-coverage/coverage-summary.json')
|
||||
) {
|
||||
try {
|
||||
lines.push('')
|
||||
lines.push(renderCriticalCoverageReport())
|
||||
} catch {
|
||||
lines.push('')
|
||||
lines.push('## Critical Unit Coverage')
|
||||
lines.push('')
|
||||
lines.push(
|
||||
'> Failed to render critical coverage summary. Check the CI workflow logs.'
|
||||
)
|
||||
}
|
||||
} else if (criticalCoverageStatus === 'ready') {
|
||||
lines.push('')
|
||||
lines.push('## Critical Unit Coverage')
|
||||
lines.push('')
|
||||
lines.push('> Critical coverage summary unavailable.')
|
||||
} else if (criticalCoverageStatus === 'failed') {
|
||||
lines.push('')
|
||||
lines.push('## Critical Unit Coverage')
|
||||
lines.push('')
|
||||
lines.push('> Critical coverage gate failed. Check the CI workflow logs.')
|
||||
} else if (criticalCoverageStatus === 'pending') {
|
||||
lines.push('')
|
||||
lines.push('## Critical Unit Coverage')
|
||||
lines.push('')
|
||||
lines.push('> Critical coverage gate is still running.')
|
||||
}
|
||||
|
||||
if (
|
||||
process.argv[1] &&
|
||||
import.meta.url === pathToFileURL(process.argv[1]).href
|
||||
) {
|
||||
process.stdout.write(lines.join('\n') + '\n')
|
||||
}
|
||||
process.stdout.write(lines.join('\n') + '\n')
|
||||
|
||||
75
src/base/common/async.test.ts
Normal file
75
src/base/common/async.test.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
describe('runWhenGlobalIdle', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('falls back to a timeout when idle callbacks are unavailable', async () => {
|
||||
vi.useFakeTimers()
|
||||
vi.stubGlobal('requestIdleCallback', undefined)
|
||||
vi.stubGlobal('cancelIdleCallback', undefined)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
const disposable = runWhenGlobalIdle(runner)
|
||||
await vi.runAllTimersAsync()
|
||||
|
||||
expect(runner).toHaveBeenCalledOnce()
|
||||
const deadline = runner.mock.calls[0][0]
|
||||
expect(deadline.didTimeout).toBe(true)
|
||||
expect(deadline.timeRemaining()).toBeGreaterThanOrEqual(0)
|
||||
|
||||
disposable.dispose()
|
||||
disposable.dispose()
|
||||
})
|
||||
|
||||
it('cancels fallback idle work before it runs', async () => {
|
||||
vi.useFakeTimers()
|
||||
vi.stubGlobal('requestIdleCallback', undefined)
|
||||
vi.stubGlobal('cancelIdleCallback', undefined)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
runWhenGlobalIdle(runner).dispose()
|
||||
await vi.runAllTimersAsync()
|
||||
|
||||
expect(runner).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses native idle callbacks when available', async () => {
|
||||
const requestIdleCallback = vi.fn(() => 42)
|
||||
const cancelIdleCallback = vi.fn()
|
||||
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
|
||||
vi.stubGlobal('cancelIdleCallback', cancelIdleCallback)
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
const disposable = runWhenGlobalIdle(runner, 250)
|
||||
|
||||
expect(requestIdleCallback).toHaveBeenCalledWith(runner, { timeout: 250 })
|
||||
|
||||
disposable.dispose()
|
||||
disposable.dispose()
|
||||
|
||||
expect(cancelIdleCallback).toHaveBeenCalledOnce()
|
||||
expect(cancelIdleCallback).toHaveBeenCalledWith(42)
|
||||
})
|
||||
|
||||
it('omits native idle timeout options when no timeout is supplied', async () => {
|
||||
const requestIdleCallback = vi.fn(() => 7)
|
||||
vi.stubGlobal('requestIdleCallback', requestIdleCallback)
|
||||
vi.stubGlobal('cancelIdleCallback', vi.fn())
|
||||
const { runWhenGlobalIdle } = await import('./async')
|
||||
const runner = vi.fn()
|
||||
|
||||
runWhenGlobalIdle(runner)
|
||||
|
||||
expect(requestIdleCallback).toHaveBeenCalledWith(runner, undefined)
|
||||
})
|
||||
})
|
||||
@@ -122,6 +122,22 @@ describe('downloadUtil', () => {
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws for an empty URL', () => {
|
||||
expect(() => downloadFile('')).toThrow(
|
||||
'Invalid URL provided for download'
|
||||
)
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws for a whitespace URL', () => {
|
||||
expect(() => downloadFile(' ')).toThrow(
|
||||
'Invalid URL provided for download'
|
||||
)
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
expect(createObjectURLSpy).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should prefer custom filename over extracted filename', () => {
|
||||
const testUrl =
|
||||
'https://example.com/api/file?filename=extracted-image.jpg'
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
CREDITS_PER_USD,
|
||||
COMFY_CREDIT_RATE_CENTS,
|
||||
centsToCredits,
|
||||
clampUsd,
|
||||
creditsToCents,
|
||||
creditsToUsd,
|
||||
formatCredits,
|
||||
@@ -43,4 +44,21 @@ describe('comfyCredits helpers', () => {
|
||||
expect(formatCreditsFromUsd({ usd: 1, locale })).toBe('211.00')
|
||||
expect(formatUsd({ value: 4.2, locale })).toBe('4.20')
|
||||
})
|
||||
|
||||
test('formats with compatible fraction digit bounds', () => {
|
||||
expect(
|
||||
formatCredits({
|
||||
value: 12.345,
|
||||
locale: 'en-US',
|
||||
numberOptions: { minimumFractionDigits: 4, maximumFractionDigits: 2 }
|
||||
})
|
||||
).toBe('12.35')
|
||||
})
|
||||
|
||||
test('clamps USD purchase values into the supported range', () => {
|
||||
expect(clampUsd(Number.NaN)).toBe(0)
|
||||
expect(clampUsd(-5)).toBe(1)
|
||||
expect(clampUsd(42)).toBe(42)
|
||||
expect(clampUsd(5000)).toBe(1000)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
|
||||
const subscriptionDialog = useSubscriptionDialog()
|
||||
|
||||
function handleClick() {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ComputedRef, Ref } from 'vue'
|
||||
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type {
|
||||
BillingStatus,
|
||||
@@ -75,9 +76,10 @@ export interface BillingActions {
|
||||
*/
|
||||
requireActiveSubscription: () => Promise<void>
|
||||
/**
|
||||
* Shows the subscription dialog.
|
||||
* Shows the subscription dialog. Pass a reason so the paywall open and any
|
||||
* downstream checkout stay attributed to the triggering product moment.
|
||||
*/
|
||||
showSubscriptionDialog: () => void
|
||||
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
|
||||
}
|
||||
|
||||
export interface BillingState {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
getTierFeatures
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
PreviewSubscribeOptions,
|
||||
SubscribeOptions
|
||||
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
|
||||
return activeContext.value.requireActiveSubscription()
|
||||
}
|
||||
|
||||
function showSubscriptionDialog() {
|
||||
return activeContext.value.showSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
|
||||
return activeContext.value.showSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
|
||||
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingStatus,
|
||||
BillingSubscriptionStatus,
|
||||
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
legacyShowSubscriptionDialog()
|
||||
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
legacyShowSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
legacyShowSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,6 @@ function handleClose() {
|
||||
}
|
||||
|
||||
function handleSubscribe() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
|
||||
// Shows loading affordances
|
||||
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
|
||||
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'team_700',
|
||||
'yearly'
|
||||
'yearly',
|
||||
{ paymentIntentSource: 'deep_link' }
|
||||
)
|
||||
// Team never goes through the personal checkout path
|
||||
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()
|
||||
|
||||
@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
return
|
||||
}
|
||||
isTeamCheckout.value = true
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle)
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle, {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
if (isActiveSubscription.value) {
|
||||
await accessBillingPortal(undefined, false)
|
||||
} else {
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
}
|
||||
}, reportError)
|
||||
|
||||
|
||||
@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
|
||||
})
|
||||
|
||||
function handleAddCredits() {
|
||||
telemetry?.trackAddApiCreditButtonClicked()
|
||||
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
function handleUpgradeToAddCredits() {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
}
|
||||
|
||||
async function handleWindowFocus() {
|
||||
|
||||
@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import FreeTierDialogContent from './FreeTierDialogContent.vue'
|
||||
|
||||
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
|
||||
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
function renderComponent() {
|
||||
function renderComponent(props?: { reason?: PaymentIntentSource }) {
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
@@ -23,6 +25,7 @@ function renderComponent() {
|
||||
})
|
||||
|
||||
return render(FreeTierDialogContent, {
|
||||
props,
|
||||
global: {
|
||||
plugins: [i18n]
|
||||
}
|
||||
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
|
||||
renderComponent()
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('keeps the generic copy for intent reasons outside the credits variants', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'subscribe_to_run' })
|
||||
expect(
|
||||
screen.getByText('Your credits refresh on Jul 15, 2026.')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('swaps to the out-of-credits copy without the refresh line', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="!reason || reason === 'subscription_required'"
|
||||
v-if="!isCreditsBlockedVariant"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -65,10 +65,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="
|
||||
(!reason || reason === 'subscription_required') &&
|
||||
formattedRenewalDate
|
||||
"
|
||||
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -88,7 +85,7 @@
|
||||
@click="$emit('upgrade')"
|
||||
>
|
||||
{{
|
||||
reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
isCreditsBlockedVariant
|
||||
? $t('subscription.freeTier.upgradeCta')
|
||||
: $t('subscription.freeTier.subscribeCta')
|
||||
}}
|
||||
@@ -103,12 +100,12 @@ import { computed } from 'vue'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
|
||||
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
|
||||
defineProps<{
|
||||
reason?: SubscriptionDialogReason
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
defineEmits<{
|
||||
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
|
||||
})
|
||||
|
||||
const freeTierCredits = computed(() => getTierCredits('free'))
|
||||
|
||||
// Only these two variants replace the generic free-tier copy; any other
|
||||
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
|
||||
const isCreditsBlockedVariant = computed(
|
||||
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
)
|
||||
</script>
|
||||
|
||||
@@ -261,6 +261,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
|
||||
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should use the latest userId value when it changes after mount', async () => {
|
||||
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
})
|
||||
|
||||
@@ -277,13 +277,19 @@ import type {
|
||||
TierKey,
|
||||
TierPricing
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import {
|
||||
recordPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
|
||||
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
@@ -321,6 +327,10 @@ interface PricingTierConfig {
|
||||
isPopular?: boolean
|
||||
}
|
||||
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
chooseTeamWorkspace: []
|
||||
}>()
|
||||
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
} as const
|
||||
const previousPlan = currentPlanDescriptor.value
|
||||
const checkoutAttribution = await getCheckoutAttributionForCloud()
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
})
|
||||
}
|
||||
const beginCheckoutMetadata = userId.value
|
||||
? {
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change' as const,
|
||||
...(reason ? { payment_intent_source: reason } : {}),
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
}
|
||||
: null
|
||||
// Pass the target tier to create a deep link to subscription update confirmation
|
||||
const checkoutTier = getCheckoutTier(
|
||||
targetPlan.tierKey,
|
||||
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
|
||||
if (downgrade) {
|
||||
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
|
||||
await accessBillingPortal()
|
||||
const didOpenPortal = await accessBillingPortal()
|
||||
if (didOpenPortal && beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
|
||||
}
|
||||
} else {
|
||||
const didOpenPortal = await accessBillingPortal(checkoutTier)
|
||||
if (!didOpenPortal) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
payment_intent_source: reason,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
|
||||
...(previousPlan
|
||||
? { previous_cycle: previousPlan.billingCycle }
|
||||
: {})
|
||||
})
|
||||
if (beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
beginCheckoutMetadata,
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await performSubscriptionCheckout(
|
||||
tierKey,
|
||||
currentBillingCycle.value,
|
||||
true
|
||||
)
|
||||
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
|
||||
paymentIntentSource: reason
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
|
||||
@@ -56,7 +56,7 @@ const handleSubscribe = () => {
|
||||
current_tier: tier.value?.toLowerCase()
|
||||
})
|
||||
isAwaitingStripeSubscription.value = true
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
|
||||
@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
|
||||
trackRunButton({ subscribe_to_run: true })
|
||||
}
|
||||
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -48,7 +48,9 @@
|
||||
v-if="isActiveSubscription"
|
||||
variant="primary"
|
||||
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
|
||||
@click="showSubscriptionDialog"
|
||||
@click="
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
"
|
||||
>
|
||||
{{ $t('subscription.upgradePlan') }}
|
||||
</Button>
|
||||
|
||||
@@ -33,7 +33,11 @@
|
||||
</i18n-t>
|
||||
</div>
|
||||
|
||||
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
|
||||
<PricingTable
|
||||
:reason
|
||||
class="flex-1"
|
||||
@choose-team-workspace="handleChooseTeam"
|
||||
/>
|
||||
|
||||
<!-- Contact and Enterprise Links -->
|
||||
<div class="flex flex-col items-center gap-2">
|
||||
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
const { onClose, reason, onChooseTeam } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
onChooseTeam?: () => void
|
||||
}>()
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
|
||||
)
|
||||
return
|
||||
case 'subscription':
|
||||
void dialogService.showSubscriptionRequiredDialog()
|
||||
void dialogService.showSubscriptionRequiredDialog({
|
||||
reason: 'subscription_required'
|
||||
})
|
||||
return
|
||||
case 'credits':
|
||||
void dialogService.showTopUpCreditsDialog({
|
||||
|
||||
@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
describe('usePricingTableUrlLoader', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
reason: 'deep_link',
|
||||
planMode: undefined
|
||||
})
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
reason: 'deep_link'
|
||||
})
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
})
|
||||
|
||||
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('denies, strips, and clears together when the user is not eligible', async () => {
|
||||
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({
|
||||
query: { other: 'param' }
|
||||
})
|
||||
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
)
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
|
||||
'pricing'
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
mergePreservedQueryIntoQuery
|
||||
} from '@/platform/navigation/preservedQueryManager'
|
||||
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
|
||||
const planMode =
|
||||
param === 'team' || param === 'personal' ? param : undefined
|
||||
|
||||
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
|
||||
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
|
||||
})
|
||||
}, reportError)
|
||||
|
||||
const showSubscriptionDialog = (options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) => {
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: subscriptionTier.value?.toLowerCase(),
|
||||
reason: options?.reason
|
||||
})
|
||||
|
||||
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
|
||||
void showSubscriptionRequiredDialog(options)
|
||||
}
|
||||
|
||||
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
|
||||
await fetchSubscriptionStatus()
|
||||
|
||||
if (!isSubscribedOrIsNotCloud.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
|
||||
}))
|
||||
|
||||
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
|
||||
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
|
||||
const {
|
||||
mockIsCloud,
|
||||
mockTrackHelpResourceClicked,
|
||||
mockTrackAddApiCreditButtonClicked
|
||||
} = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockTrackHelpResourceClicked: vi.fn()
|
||||
mockTrackHelpResourceClicked: vi.fn(),
|
||||
mockTrackAddApiCreditButtonClicked: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () =>
|
||||
mockIsCloud.value
|
||||
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
|
||||
? {
|
||||
trackHelpResourceClicked: mockTrackHelpResourceClicked,
|
||||
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
|
||||
}
|
||||
: null
|
||||
}))
|
||||
|
||||
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
|
||||
const { handleAddApiCredits } = useSubscriptionActions()
|
||||
handleAddApiCredits()
|
||||
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
|
||||
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
|
||||
})
|
||||
|
||||
const handleAddApiCredits = () => {
|
||||
telemetry?.trackAddApiCreditButtonClicked({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
|
||||
const mockCloseDialog = vi.fn()
|
||||
const mockShowLayoutDialog = vi.fn()
|
||||
const mockShowTeamWorkspacesDialog = vi.fn()
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
|
||||
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
|
||||
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
|
||||
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({
|
||||
isFreeTier: mockIsFreeTier,
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan,
|
||||
tier: mockTier
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
|
||||
useWorkspaceUI: () => ({
|
||||
permissions: {
|
||||
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
|
||||
mockIsCloud.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
mockIsFreeTier.value = false
|
||||
mockTier.value = 'FREE'
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
mockIsLegacyTeamPlan.value = false
|
||||
mockCanManageSubscription.value = true
|
||||
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
|
||||
const props = mockShowLayoutDialog.mock.calls[0][0].props
|
||||
expect(props.initialPlanMode).toBe('team')
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the caller reason and current tier', () => {
|
||||
mockTier.value = 'STANDARD'
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
current_tier: 'standard',
|
||||
reason: 'upgrade_to_add_credits'
|
||||
})
|
||||
})
|
||||
|
||||
it('tracks modal_opened on the workspace (unified) path too', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'subscribe_to_run' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not track modal_opened for the inactive member dialog', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockCanManageSubscription.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not track on non-cloud', () => {
|
||||
mockIsCloud.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('show', () => {
|
||||
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
|
||||
expect.objectContaining({ key: 'subscription-required' })
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the reason for the free-tier dialog', () => {
|
||||
mockIsFreeTier.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { show } = useSubscriptionDialog()
|
||||
|
||||
show({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'out_of_credits' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('startTeamWorkspaceUpgradeFlow', () => {
|
||||
|
||||
@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
|
||||
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
|
||||
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
|
||||
|
||||
export type SubscriptionDialogReason =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
|
||||
interface SubscriptionDialogOptions {
|
||||
reason?: SubscriptionDialogReason
|
||||
export interface SubscriptionDialogOptions {
|
||||
reason?: PaymentIntentSource
|
||||
/**
|
||||
* Forces the unified pricing dialog to open on a specific plan tab,
|
||||
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
|
||||
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
|
||||
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
|
||||
}
|
||||
|
||||
// Fired here — the choke point every paywall/pricing dialog variant passes
|
||||
// through — so both the legacy and workspace billing paths emit it.
|
||||
function trackModalOpened(reason?: PaymentIntentSource) {
|
||||
// Resolved lazily to avoid the useBillingContext import cycle (see below).
|
||||
const { tier } = useBillingContext()
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: tier.value?.toLowerCase(),
|
||||
reason
|
||||
})
|
||||
}
|
||||
|
||||
function showPricingTable(options?: SubscriptionDialogOptions) {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
|
||||
return
|
||||
}
|
||||
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
// Shared dialog shell styling for both variants.
|
||||
const dialogComponentProps = {
|
||||
style: 'width: min(1328px, 95vw); max-height: 958px;',
|
||||
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
|
||||
// (not at composable setup) to avoid the useBillingContext import cycle.
|
||||
const { isFreeTier } = useBillingContext()
|
||||
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
const component = defineAsyncComponent(
|
||||
() =>
|
||||
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
|
||||
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
|
||||
sessionStorage.removeItem(RESUME_PRICING_KEY)
|
||||
|
||||
if (!workspaceStore.isInPersonalWorkspace) {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'team_upgrade_resume' })
|
||||
}
|
||||
} catch {
|
||||
// sessionStorage may be unavailable
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
clearPendingSubscriptionCheckoutAttempt,
|
||||
consumePendingSubscriptionCheckoutSuccess,
|
||||
recordPendingSubscriptionCheckoutAttempt
|
||||
} from './subscriptionCheckoutTracker'
|
||||
|
||||
const activeProStatus = {
|
||||
is_active: true,
|
||||
subscription_tier: 'PRO',
|
||||
subscription_duration: 'MONTHLY'
|
||||
} as const
|
||||
|
||||
describe('subscriptionCheckoutTracker', () => {
|
||||
beforeEach(() => {
|
||||
clearPendingSubscriptionCheckoutAttempt()
|
||||
})
|
||||
|
||||
it('round-trips payment_intent_source from attempt to success metadata', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).toMatchObject({
|
||||
tier: 'pro',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('omits payment_intent_source when the attempt had none', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).not.toBeNull()
|
||||
expect(metadata).not.toHaveProperty('payment_intent_source')
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,12 @@ import type {
|
||||
TierKey
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
BeginCheckoutMetadata,
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType,
|
||||
SubscriptionSuccessMetadata
|
||||
} from '@/platform/telemetry/types'
|
||||
|
||||
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
|
||||
const VALID_TIER_KEYS = new Set<TierKey>([
|
||||
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
|
||||
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
|
||||
'comfy:subscription-checkout-attempt-changed'
|
||||
|
||||
type CheckoutType = 'new' | 'change'
|
||||
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
|
||||
|
||||
interface SubscriptionStatusSnapshot {
|
||||
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
|
||||
subscription_duration?: SubscriptionDuration | null
|
||||
}
|
||||
|
||||
interface PendingSubscriptionCheckoutAttempt {
|
||||
export interface PendingSubscriptionCheckoutAttempt {
|
||||
attempt_id: string
|
||||
started_at_ms: number
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface RecordPendingSubscriptionCheckoutAttemptInput {
|
||||
interface PendingSubscriptionCheckoutAttemptInput {
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
const dispatchPendingCheckoutChangeEvent = () => {
|
||||
@@ -168,6 +174,9 @@ const normalizeAttempt = (
|
||||
...(candidate.previous_cycle === 'monthly' ||
|
||||
candidate.previous_cycle === 'yearly'
|
||||
? { previous_cycle: candidate.previous_cycle }
|
||||
: {}),
|
||||
...(typeof candidate.payment_intent_source === 'string'
|
||||
? { payment_intent_source: candidate.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
|
||||
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
|
||||
getPendingSubscriptionCheckoutAttempt() !== null
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: RecordPendingSubscriptionCheckoutAttemptInput
|
||||
export const createPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
const attempt: PendingSubscriptionCheckoutAttempt = {
|
||||
return {
|
||||
attempt_id: createAttemptId(),
|
||||
started_at_ms: Date.now(),
|
||||
tier: input.tier,
|
||||
cycle: input.cycle,
|
||||
checkout_type: input.checkout_type,
|
||||
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
|
||||
...(input.payment_intent_source
|
||||
? { payment_intent_source: input.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
|
||||
export const persistPendingSubscriptionCheckoutAttempt = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
if (!storage) {
|
||||
return attempt
|
||||
}
|
||||
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
return attempt
|
||||
}
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt =>
|
||||
persistPendingSubscriptionCheckoutAttempt(
|
||||
createPendingSubscriptionCheckoutAttempt(input)
|
||||
)
|
||||
|
||||
export const withPendingCheckoutAttemptId = (
|
||||
metadata: BeginCheckoutMetadata,
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): BeginCheckoutMetadata => ({
|
||||
...metadata,
|
||||
checkout_attempt_id: attempt.attempt_id
|
||||
})
|
||||
|
||||
const didAttemptSucceed = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt,
|
||||
status: SubscriptionStatusSnapshot
|
||||
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
|
||||
cycle: attempt.cycle,
|
||||
checkout_type: attempt.checkout_type,
|
||||
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
|
||||
...(attempt.payment_intent_source
|
||||
? { payment_intent_source: attempt.payment_intent_source }
|
||||
: {}),
|
||||
value,
|
||||
currency: 'USD',
|
||||
ecommerce: {
|
||||
|
||||
@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'yearly', true)
|
||||
await performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
ga_client_id: 'ga-client-id',
|
||||
ga_session_id: 'ga-session-id',
|
||||
ga_session_number: 'ga-session-number',
|
||||
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
|
||||
gbraid: 'gbraid-456',
|
||||
wbraid: 'wbraid-789'
|
||||
})
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
JSON.parse(storedAttempt).attempt_id
|
||||
)
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'/customers/cloud-subscription-checkout/pro-yearly'
|
||||
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[SubscriptionCheckout] Failed to collect checkout attribution',
|
||||
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
.spyOn(window, 'open')
|
||||
.mockImplementation(() => window as unknown as Window)
|
||||
|
||||
vi.mocked(global.fetch).mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', {
|
||||
paymentIntentSource: 'out_of_credits'
|
||||
})
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
|
||||
)
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
const pendingAttempt = JSON.parse(storedAttempt)
|
||||
expect(pendingAttempt).toMatchObject({
|
||||
payment_intent_source: 'out_of_credits'
|
||||
})
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
pendingAttempt.attempt_id
|
||||
)
|
||||
openSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('uses the latest userId when it changes after checkout starts', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
mockUserId.value = 'user-late'
|
||||
authHeader.resolve({ Authorization: 'Bearer test-token' })
|
||||
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-late',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
|
||||
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
|
||||
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
const storedAttempt = window.localStorage.getItem(
|
||||
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
|
||||
)
|
||||
expect(storedAttempt).toBeNull()
|
||||
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { getComfyApiBaseUrl } from '@/config/comfyApi'
|
||||
import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import {
|
||||
createPendingSubscriptionCheckoutAttempt,
|
||||
persistPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
type CheckoutTier = TierKey | `${TierKey}-yearly`
|
||||
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
|
||||
return getCheckoutAttribution()
|
||||
}
|
||||
|
||||
interface PerformSubscriptionCheckoutOptions {
|
||||
openInNewTab?: boolean
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Core subscription checkout logic shared between PricingTable and
|
||||
* SubscriptionRedirectView. Handles:
|
||||
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
|
||||
export async function performSubscriptionCheckout(
|
||||
tierKey: TierKey,
|
||||
currentBillingCycle: BillingCycle,
|
||||
openInNewTab: boolean = true
|
||||
options: PerformSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
const { openInNewTab = true, paymentIntentSource } = options
|
||||
|
||||
const authStore = useAuthStore()
|
||||
const { userId } = storeToRefs(authStore)
|
||||
const telemetry = useTelemetry()
|
||||
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
|
||||
const data = await response.json()
|
||||
|
||||
if (data.checkout_url) {
|
||||
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: paymentIntentSource
|
||||
})
|
||||
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...checkoutAttribution
|
||||
})
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
{
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {}),
|
||||
...checkoutAttribution
|
||||
},
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
if (openInNewTab) {
|
||||
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
|
||||
if (!checkoutWindow) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
} else {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
globalThis.location.href = data.checkout_url
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn()
|
||||
}))
|
||||
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
|
||||
vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
|
||||
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
|
||||
workspaceApi: { subscribe: mockSubscribe }
|
||||
}))
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
|
||||
}))
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
|
||||
|
||||
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
billing_op_id: 'op_1'
|
||||
})
|
||||
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly', {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
|
||||
returnUrl: 'https://app.test/payment/success',
|
||||
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
teamCreditStopId: 'team_700'
|
||||
})
|
||||
expect(assignedHref).toBe('https://stripe.test/pay')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'team',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op_1',
|
||||
payment_intent_source: 'deep_link'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
|
||||
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
expect(assignedHref).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not track begin_checkout when subscribe fails', async () => {
|
||||
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
|
||||
|
||||
await expect(
|
||||
performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
).rejects.toThrow('subscribe failed')
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does nothing off cloud', async () => {
|
||||
mockIsCloud.value = false
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
|
||||
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
interface PerformTeamSubscriptionCheckoutOptions {
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
|
||||
* link: subscribes to the per-credit Team plan at the chosen slider stop and
|
||||
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
|
||||
*/
|
||||
export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId: string,
|
||||
billingCycle: BillingCycle
|
||||
billingCycle: BillingCycle,
|
||||
options: PerformTeamSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId
|
||||
})
|
||||
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType: 'new',
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource: options.paymentIntentSource
|
||||
})
|
||||
|
||||
if (response.status === 'needs_payment_method') {
|
||||
// A needs_payment_method response without a URL is unusable: surface it to
|
||||
// the caller's error handling rather than silently dropping the user home
|
||||
|
||||
@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
|
||||
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
|
||||
})
|
||||
|
||||
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
|
||||
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
|
||||
const b: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(a)
|
||||
registry.registerProvider(b)
|
||||
|
||||
const metadata = {
|
||||
user_id: 'user-1',
|
||||
tier: 'pro' as const,
|
||||
cycle: 'monthly' as const,
|
||||
checkout_type: 'new' as const,
|
||||
payment_intent_source: 'subscribe_to_run' as const
|
||||
}
|
||||
registry.trackBeginCheckout(metadata)
|
||||
|
||||
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
|
||||
})
|
||||
|
||||
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
|
||||
const provider: TelemetryProvider = {
|
||||
trackAddApiCreditButtonClicked: vi.fn()
|
||||
}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(provider)
|
||||
|
||||
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(
|
||||
provider.trackAddApiCreditButtonClicked
|
||||
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
|
||||
})
|
||||
|
||||
it('skips providers that do not implement trackSearchQuery', () => {
|
||||
const empty: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
|
||||
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.dispatch((provider) =>
|
||||
provider.trackAddApiCreditButtonClicked?.(metadata)
|
||||
)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('captures begin_checkout with intent metadata', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackBeginCheckout({
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.BEGIN_CHECKOUT,
|
||||
{
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('captures add-credit clicks with their source', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'credits_panel' }
|
||||
)
|
||||
})
|
||||
|
||||
it('captures share attribution events', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
|
||||
import type { RemoteConfig } from '@/platform/remoteConfig/types'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
EnterLinearMetadata,
|
||||
ShareFlowMetadata,
|
||||
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
|
||||
this.trackEvent(eventName, metadata)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
|
||||
}
|
||||
|
||||
trackMonthlySubscriptionSucceeded(
|
||||
|
||||
@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards add-credit clicks with their source', () => {
|
||||
new HostTelemetrySink().trackAddApiCreditButtonClicked({
|
||||
source: 'avatar_menu'
|
||||
})
|
||||
|
||||
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'avatar_menu' }
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when the host bridge is absent', () => {
|
||||
delete window.__comfyDesktop2
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
|
||||
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -12,12 +12,29 @@
|
||||
* 3. Check dist/assets/*.js files contain no tracking code
|
||||
*/
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
import type { AppMode } from '@/utils/appMode'
|
||||
|
||||
export type PaymentIntentSource =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
| 'subscribe_to_run'
|
||||
| 'subscribe_now_button'
|
||||
| 'upgrade_to_add_credits'
|
||||
| 'settings_billing_panel'
|
||||
| 'avatar_menu_plans'
|
||||
| 'team_members_panel'
|
||||
| 'invite_member_upsell'
|
||||
| 'upload_model_upgrade'
|
||||
| 'team_upgrade_resume'
|
||||
|
||||
export type SubscriptionCheckoutType = 'new' | 'change'
|
||||
export type SubscriptionCheckoutTier = TierKey | 'team'
|
||||
|
||||
/**
|
||||
* Authentication metadata for sign-up tracking
|
||||
*/
|
||||
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
|
||||
|
||||
export interface SubscriptionMetadata {
|
||||
current_tier?: string
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export interface AddCreditsClickMetadata {
|
||||
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
|
||||
}
|
||||
|
||||
export interface BeginCheckoutMetadata
|
||||
extends Record<string, unknown>, CheckoutAttributionMetadata {
|
||||
user_id: string
|
||||
tier: TierKey
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
checkout_attempt_id?: string
|
||||
billing_op_id?: string
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface EcommerceItemMetadata {
|
||||
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
|
||||
checkout_attempt_id: string
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
value: number
|
||||
currency: string
|
||||
ecommerce: EcommerceMetadata
|
||||
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
|
||||
metadata?: SubscriptionSuccessMetadata
|
||||
): void
|
||||
trackMonthlySubscriptionCancelled?(): void
|
||||
trackAddApiCreditButtonClicked?(): void
|
||||
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
|
||||
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
|
||||
trackApiCreditTopupSucceeded?(): void
|
||||
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void
|
||||
|
||||
@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -391,12 +391,13 @@ const showZeroState = computed(
|
||||
)
|
||||
|
||||
function handleSubscribeWorkspace() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleUpgrade() {
|
||||
if (isFreeTierPlan.value) showPricingTable()
|
||||
else showSubscriptionDialog()
|
||||
if (isFreeTierPlan.value)
|
||||
showPricingTable({ reason: 'settings_billing_panel' })
|
||||
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleViewMoreDetails() {
|
||||
|
||||
@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
|
||||
import { useEventListener } from '@vueuse/core'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
|
||||
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
|
||||
|
||||
const { onClose, reason, initialPlanMode } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
initialPlanMode?: 'personal' | 'team'
|
||||
}>()
|
||||
|
||||
@@ -152,7 +152,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleTeamSubscribe,
|
||||
handleResubscribe
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
|
||||
// Backspace mirrors the back arrow on the confirm step, but never while an
|
||||
// editable element is focused (let it delete text there).
|
||||
|
||||
@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
|
||||
|
||||
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
|
||||
const mockHandleSuccessClose = vi.fn()
|
||||
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
|
||||
const mockPreviewData = ref<{ transition_type: string } | null>(null)
|
||||
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
|
||||
useSubscriptionCheckout: () => ({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
useSubscriptionCheckout: mockUseSubscriptionCheckout
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
@@ -91,7 +76,7 @@ const SuccessStub = {
|
||||
function renderComponent(
|
||||
props: {
|
||||
onClose?: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
} = {}
|
||||
) {
|
||||
@@ -121,6 +106,23 @@ function renderComponent(
|
||||
describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseSubscriptionCheckout.mockReturnValue({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
mockCheckoutStep.value = 'pricing'
|
||||
mockPreviewData.value = null
|
||||
})
|
||||
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('passes the reason into subscription checkout', () => {
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
'out_of_credits'
|
||||
)
|
||||
})
|
||||
|
||||
it('shows the team workspace header by default', () => {
|
||||
renderComponent()
|
||||
expect(screen.getByText('Team Workspace')).toBeInTheDocument()
|
||||
|
||||
@@ -116,7 +116,7 @@
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import PricingTableWorkspace from './PricingTableWorkspace.vue'
|
||||
@@ -130,7 +130,7 @@ const {
|
||||
isPersonal = false
|
||||
} = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
}>()
|
||||
|
||||
@@ -154,7 +154,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleResubscribe,
|
||||
handleSuccessClose
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -61,6 +61,9 @@ function onDismiss() {
|
||||
|
||||
function onUpgrade() {
|
||||
dialogStore.closeDialog({ key: 'invite-member-upsell' })
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({
|
||||
planMode: 'team',
|
||||
reason: 'invite_member_upsell'
|
||||
})
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -277,7 +277,7 @@ export function useMembersPanel() {
|
||||
}
|
||||
|
||||
function showTeamPlans() {
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed } from 'vue'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import type { Plan } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
import { findPlanSlug } from './useSubscriptionCheckout'
|
||||
@@ -75,7 +76,9 @@ const {
|
||||
mockPlans,
|
||||
mockResubscribe,
|
||||
mockToastAdd,
|
||||
mockStartOperation
|
||||
mockStartOperation,
|
||||
mockTrackBeginCheckout,
|
||||
mockUserId
|
||||
} = vi.hoisted(() => ({
|
||||
mockSubscribe: vi.fn(),
|
||||
mockPreviewSubscribe: vi.fn(),
|
||||
@@ -84,7 +87,9 @@ const {
|
||||
mockPlans: { value: [] as Plan[] },
|
||||
mockResubscribe: vi.fn(),
|
||||
mockToastAdd: vi.fn(),
|
||||
mockStartOperation: vi.fn()
|
||||
mockStartOperation: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
|
||||
useTelemetry: () => ({
|
||||
trackMonthlySubscriptionSucceeded: vi.fn(),
|
||||
trackBeginCheckout: mockTrackBeginCheckout
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', async (importOriginal) => {
|
||||
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
|
||||
describe('useSubscriptionCheckout', () => {
|
||||
let emit: ReturnType<typeof vi.fn>
|
||||
|
||||
async function setup() {
|
||||
async function setup(paymentIntentSource?: PaymentIntentSource) {
|
||||
const { useSubscriptionCheckout } =
|
||||
await import('./useSubscriptionCheckout')
|
||||
return useSubscriptionCheckout(emit as never)
|
||||
return useSubscriptionCheckout(emit as never, paymentIntentSource)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
vi.clearAllMocks()
|
||||
mockPlans.value = allPlans()
|
||||
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
|
||||
mockUserId.value = 'user-1'
|
||||
emit = vi.fn()
|
||||
})
|
||||
|
||||
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
|
||||
cancelUrl: 'https://platform.comfy.org/payment/failed'
|
||||
})
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-team-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('uses the annual plan slug for the yearly cycle', async () => {
|
||||
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Team payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('keeps team checkout_type as change when the preview request fails', async () => {
|
||||
const checkout = await setup()
|
||||
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
|
||||
await checkout.handleSubscribeTeamClick({
|
||||
stop: {
|
||||
id: 'team_1400',
|
||||
usd: 1400,
|
||||
credits: 295_400,
|
||||
discountedUsd: 1295
|
||||
},
|
||||
billingCycle: 'monthly',
|
||||
isChange: true
|
||||
})
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleTeamSubscribe()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'change',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
})
|
||||
|
||||
it('skips begin_checkout when no user id is available', async () => {
|
||||
mockUserId.value = null
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
mockUserId.value = 'user-1'
|
||||
})
|
||||
|
||||
it('fires begin_checkout carrying the payment intent source', async () => {
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'standard',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-1',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('opens payment URL when needs_payment_method', async () => {
|
||||
const checkout = await setup()
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type {
|
||||
Plan,
|
||||
PreviewSubscribeResponse,
|
||||
SubscribeResponse
|
||||
} from '@/platform/workspace/api/workspaceApi'
|
||||
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
type CheckoutStep = 'pricing' | 'preview' | 'success'
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
|
||||
interface SelectedTeamCheckout {
|
||||
stop: TeamPlanSelection
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
}
|
||||
|
||||
/**
|
||||
* Which screen the `preview` step shows. Only a change prorates: a team change
|
||||
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
|
||||
@@ -45,9 +55,12 @@ export function findPlanSlug(
|
||||
return plan?.slug ?? null
|
||||
}
|
||||
|
||||
export function useSubscriptionCheckout(emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
}) {
|
||||
export function useSubscriptionCheckout(
|
||||
emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
},
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
) {
|
||||
const { t } = useI18n()
|
||||
const toast = useToast()
|
||||
const {
|
||||
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
|
||||
const isResubscribing = ref(false)
|
||||
const previewData = ref<PreviewSubscribeResponse | null>(null)
|
||||
const selectedTierKey = ref<CheckoutTierKey | null>(null)
|
||||
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
|
||||
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
|
||||
const selectedBillingCycle = ref<BillingCycle>('yearly')
|
||||
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
|
||||
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
|
||||
const selectedTeamStop = computed(
|
||||
() => selectedTeamCheckout.value?.stop ?? null
|
||||
)
|
||||
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
|
||||
|
||||
const previewVariant = computed<PreviewVariant>(() => {
|
||||
if (selectedTeamStop.value) {
|
||||
if (selectedTeamCheckout.value) {
|
||||
return previewData.value ? 'team-change' : 'team-new'
|
||||
}
|
||||
if (previewData.value) {
|
||||
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
|
||||
billingCycle: BillingCycle
|
||||
isChange?: boolean
|
||||
}) {
|
||||
selectedTeamStop.value = payload.stop
|
||||
selectedTeamCheckout.value = {
|
||||
stop: payload.stop,
|
||||
checkoutType: payload.isChange ? 'change' : 'new'
|
||||
}
|
||||
selectedBillingCycle.value = payload.billingCycle
|
||||
selectedTierKey.value = null
|
||||
previewData.value = null
|
||||
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
|
||||
function handleBackToPricing() {
|
||||
checkoutStep.value = 'pricing'
|
||||
previewData.value = null
|
||||
selectedTeamStop.value = null
|
||||
selectedTeamCheckout.value = null
|
||||
}
|
||||
|
||||
function handleSuccessClose() {
|
||||
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleSubscription() {
|
||||
if (!selectedTierKey.value) return
|
||||
const tierKey = selectedTierKey.value
|
||||
if (!tierKey) return
|
||||
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
const checkoutType =
|
||||
previewData.value &&
|
||||
previewData.value.transition_type !== 'new_subscription'
|
||||
? 'change'
|
||||
: 'new'
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getApiPlanSlug(
|
||||
selectedTierKey.value,
|
||||
selectedBillingCycle.value
|
||||
)
|
||||
const planSlug = getApiPlanSlug(tierKey, billingCycle)
|
||||
if (!planSlug) return
|
||||
const response = await subscribe(planSlug, {
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: tierKey,
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleTeamSubscription() {
|
||||
const stop = selectedTeamStop.value
|
||||
if (!stop?.id) {
|
||||
const teamCheckout = selectedTeamCheckout.value
|
||||
if (!teamCheckout?.stop.id) {
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: t('subscription.teamPlan.name'),
|
||||
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
|
||||
return
|
||||
}
|
||||
|
||||
const { stop, checkoutType } = teamCheckout
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
|
||||
const planSlug = getTeamPlanSlug(billingCycle)
|
||||
const response = await subscribe(planSlug, {
|
||||
teamCreditStopId: stop.id,
|
||||
billingCycle: selectedBillingCycle.value,
|
||||
billingCycle,
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
|
||||
|
||||
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
|
||||
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingBalanceResponse,
|
||||
BillingStatusResponse,
|
||||
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
subscriptionDialog.show()
|
||||
subscriptionDialog.show({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
subscriptionDialog.show()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
subscriptionDialog.show(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutTier,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
interface TrackWorkspaceCheckoutStartedOptions {
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
billingOpId: string
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export function trackWorkspaceCheckoutStarted({
|
||||
tier,
|
||||
cycle,
|
||||
checkoutType,
|
||||
billingOpId,
|
||||
paymentIntentSource
|
||||
}: TrackWorkspaceCheckoutStartedOptions) {
|
||||
const { userId } = useAuthStore()
|
||||
if (!userId) return
|
||||
|
||||
useTelemetry()?.trackBeginCheckout({
|
||||
user_id: userId,
|
||||
tier,
|
||||
cycle,
|
||||
checkout_type: checkoutType,
|
||||
billing_op_id: billingOpId,
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {})
|
||||
})
|
||||
}
|
||||
57
src/schemas/nodeDefSchema.test.ts
Normal file
57
src/schemas/nodeDefSchema.test.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
getComboSpecComboOptions,
|
||||
getInputSpecType,
|
||||
isComboInputSpec,
|
||||
isComboInputSpecV1,
|
||||
isComboInputSpecV2,
|
||||
isFloatInputSpec,
|
||||
isIntInputSpec,
|
||||
isMediaUploadComboInput
|
||||
} from './nodeDefSchema'
|
||||
import type {
|
||||
ComboInputSpec,
|
||||
ComboInputSpecV2,
|
||||
InputSpec
|
||||
} from './nodeDefSchema'
|
||||
|
||||
describe('node definition schema helpers', () => {
|
||||
it('identifies input spec variants', () => {
|
||||
const intSpec: InputSpec = ['INT', {}]
|
||||
const floatSpec: InputSpec = ['FLOAT', {}]
|
||||
const comboV1: ComboInputSpec = [['a', 'b'], {}]
|
||||
const comboV2: ComboInputSpecV2 = ['COMBO', { options: ['a', 'b'] }]
|
||||
|
||||
expect(isIntInputSpec(intSpec)).toBe(true)
|
||||
expect(isFloatInputSpec(floatSpec)).toBe(true)
|
||||
expect(isComboInputSpecV1(comboV1)).toBe(true)
|
||||
expect(isComboInputSpecV2(comboV2)).toBe(true)
|
||||
expect(isComboInputSpec(comboV1)).toBe(true)
|
||||
expect(isComboInputSpec(comboV2)).toBe(true)
|
||||
expect(getInputSpecType(comboV1)).toBe('COMBO')
|
||||
expect(getInputSpecType(intSpec)).toBe('INT')
|
||||
})
|
||||
|
||||
it('reads combo options from legacy and v2 combo specs', () => {
|
||||
expect(getComboSpecComboOptions([['a', 1], {}])).toEqual(['a', 1])
|
||||
expect(
|
||||
getComboSpecComboOptions(['COMBO', { options: ['x', 'y'] }])
|
||||
).toEqual(['x', 'y'])
|
||||
expect(getComboSpecComboOptions(['COMBO', {}])).toEqual([])
|
||||
})
|
||||
|
||||
it('detects media upload combo inputs', () => {
|
||||
expect(isMediaUploadComboInput([['a'], { image_upload: true }])).toBe(true)
|
||||
expect(
|
||||
isMediaUploadComboInput(['COMBO', { animated_image_upload: true }])
|
||||
).toBe(true)
|
||||
expect(isMediaUploadComboInput(['COMBO', { video_upload: true }])).toBe(
|
||||
true
|
||||
)
|
||||
expect(isMediaUploadComboInput(['STRING', { image_upload: true }])).toBe(
|
||||
false
|
||||
)
|
||||
expect(isMediaUploadComboInput(['COMBO', undefined])).toBe(false)
|
||||
})
|
||||
})
|
||||
149
src/scripts/api.cloud.test.ts
Normal file
149
src/scripts/api.cloud.test.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
fetchWithUnifiedRemint,
|
||||
shouldRemintCloudRequest
|
||||
} from '@/platform/auth/unified/remintRetry'
|
||||
|
||||
const { mockAuthStore } = vi.hoisted(() => ({
|
||||
mockAuthStore: {
|
||||
isInitialized: true,
|
||||
getAuthHeader: vi.fn(),
|
||||
getAuthToken: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({ isCloud: true }))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: vi.fn(() => mockAuthStore)
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/auth/unified/remintRetry', () => ({
|
||||
fetchWithUnifiedRemint: vi.fn(),
|
||||
shouldRemintCloudRequest: vi.fn()
|
||||
}))
|
||||
|
||||
class FakeWebSocket extends EventTarget {
|
||||
static instances: FakeWebSocket[] = []
|
||||
|
||||
binaryType = ''
|
||||
sent: string[] = []
|
||||
|
||||
constructor(readonly url: string) {
|
||||
super()
|
||||
FakeWebSocket.instances.push(this)
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sent.push(data)
|
||||
}
|
||||
|
||||
close() {
|
||||
this.dispatchEvent(new Event('close'))
|
||||
}
|
||||
}
|
||||
|
||||
const { ComfyApi } = await import('./api')
|
||||
|
||||
describe('ComfyApi cloud mode', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
FakeWebSocket.instances = []
|
||||
window.name = ''
|
||||
sessionStorage.clear()
|
||||
mockAuthStore.isInitialized = true
|
||||
mockAuthStore.getAuthHeader.mockResolvedValue(null)
|
||||
mockAuthStore.getAuthToken.mockResolvedValue(null)
|
||||
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(false)
|
||||
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
|
||||
new Response(JSON.stringify({ ok: true }), {
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
})
|
||||
)
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
|
||||
})
|
||||
|
||||
it('adds cloud auth headers and enables unified retry for authenticated requests', async () => {
|
||||
mockAuthStore.getAuthHeader.mockResolvedValue({
|
||||
Authorization: 'Bearer firebase-token'
|
||||
})
|
||||
vi.mocked(shouldRemintCloudRequest).mockResolvedValue(true)
|
||||
const api = new ComfyApi()
|
||||
api.user = 'cloud-user'
|
||||
|
||||
await api.fetchApi('/queue')
|
||||
|
||||
expect(api.api_base).toBe('')
|
||||
expect(fetchWithUnifiedRemint).toHaveBeenCalledWith(
|
||||
'/api/queue',
|
||||
expect.objectContaining({
|
||||
cache: 'no-cache',
|
||||
headers: {
|
||||
Authorization: 'Bearer firebase-token',
|
||||
'Comfy-User': 'cloud-user'
|
||||
}
|
||||
}),
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('continues cloud fetches when auth header lookup fails', async () => {
|
||||
mockAuthStore.getAuthHeader.mockRejectedValue(new Error('auth unavailable'))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const api = new ComfyApi()
|
||||
|
||||
await api.fetchApi('/history', {
|
||||
headers: [['X-Test', '1']]
|
||||
})
|
||||
|
||||
const [, options, retryOn401] = vi.mocked(fetchWithUnifiedRemint).mock
|
||||
.calls[0]
|
||||
expect(options.headers).toEqual([
|
||||
['X-Test', '1'],
|
||||
['Comfy-User', '']
|
||||
])
|
||||
expect(retryOn401).toBe(false)
|
||||
expect(shouldRemintCloudRequest).not.toHaveBeenCalled()
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Failed to get auth header:',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
|
||||
it('adds the cloud auth token to websocket URLs', async () => {
|
||||
mockAuthStore.getAuthToken.mockResolvedValue('socket-token')
|
||||
window.name = 'client-1'
|
||||
const api = new ComfyApi()
|
||||
|
||||
api.init()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(FakeWebSocket.instances).toHaveLength(1)
|
||||
})
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
expect(socket.url).toContain('clientId=client-1')
|
||||
expect(socket.url).toContain('token=socket-token')
|
||||
})
|
||||
|
||||
it('opens a cloud websocket without a token when token lookup fails', async () => {
|
||||
mockAuthStore.getAuthToken.mockRejectedValue(new Error('token unavailable'))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const api = new ComfyApi()
|
||||
|
||||
api.init()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(FakeWebSocket.instances).toHaveLength(1)
|
||||
})
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
expect(socket.url).not.toContain('token=')
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Could not get auth token for WebSocket connection:',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
})
|
||||
460
src/scripts/api.core.test.ts
Normal file
460
src/scripts/api.core.test.ts
Normal file
@@ -0,0 +1,460 @@
|
||||
import axios from 'axios'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
|
||||
import type {
|
||||
ComfyApiWorkflow,
|
||||
ComfyWorkflowJSON
|
||||
} from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { PromptResponse } from '@/schemas/apiSchema'
|
||||
import { api as sharedApi, ComfyApi, PromptExecutionError } from '@/scripts/api'
|
||||
import type { NodeExecutionId } from '@/types/nodeIdentification'
|
||||
|
||||
const fetchJobs = vi.hoisted(() => ({
|
||||
fetchHistory: vi.fn(),
|
||||
fetchJobDetail: vi.fn(),
|
||||
fetchQueue: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('axios')
|
||||
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => fetchJobs)
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
function jsonResponse(data: unknown, init: ResponseInit = {}) {
|
||||
return new Response(JSON.stringify(data), {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' },
|
||||
...init
|
||||
})
|
||||
}
|
||||
|
||||
function promptData(): {
|
||||
output: ComfyApiWorkflow
|
||||
workflow: ComfyWorkflowJSON
|
||||
} {
|
||||
return {
|
||||
output: fromPartial<ComfyApiWorkflow>({
|
||||
1: {
|
||||
inputs: {},
|
||||
class_type: 'KSampler',
|
||||
_meta: { title: 'KSampler' }
|
||||
}
|
||||
}),
|
||||
workflow: fromPartial<ComfyWorkflowJSON>({
|
||||
version: 0.4,
|
||||
nodes: [],
|
||||
links: []
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function requestBody(fetchApi: ReturnType<typeof vi.spyOn>, call = 0) {
|
||||
const init = fetchApi.mock.calls[call][1]
|
||||
return JSON.parse(String(init?.body)) as Record<string, unknown>
|
||||
}
|
||||
|
||||
describe('PromptExecutionError', () => {
|
||||
it('formats string and node-specific prompt errors', () => {
|
||||
const response = fromPartial<PromptResponse>({
|
||||
error: 'invalid prompt',
|
||||
node_errors: {
|
||||
7: {
|
||||
class_type: 'KSampler',
|
||||
dependent_outputs: [],
|
||||
errors: [{ message: 'bad seed', details: 'seed must be numeric' }]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(new PromptExecutionError(response, 400).toString()).toBe(
|
||||
'invalid prompt\nKSampler:\n - bad seed: seed must be numeric'
|
||||
)
|
||||
})
|
||||
|
||||
it('formats structured prompt errors without node errors', () => {
|
||||
const response = fromPartial<PromptResponse>({
|
||||
error: {
|
||||
type: 'prompt_outputs_failed_validation',
|
||||
message: 'Validation failed',
|
||||
details: 'missing node'
|
||||
}
|
||||
})
|
||||
|
||||
expect(new PromptExecutionError(response).toString()).toBe(
|
||||
'Validation failed: missing node'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi queuePrompt', () => {
|
||||
it('sends front queue requests with auth and execution options', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.clientId = 'client-1'
|
||||
api.authToken = 'auth-token'
|
||||
api.apiKey = 'api-key'
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
|
||||
|
||||
await api.queuePrompt(-1, promptData(), {
|
||||
partialExecutionTargets: ['9:10' as NodeExecutionId],
|
||||
previewMethod: 'auto'
|
||||
})
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith('/prompt', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: fetchApi.mock.calls[0][1]?.body
|
||||
})
|
||||
expect(typeof fetchApi.mock.calls[0][1]?.body).toBe('string')
|
||||
expect(requestBody(fetchApi)).toMatchObject({
|
||||
client_id: 'client-1',
|
||||
front: true,
|
||||
partial_execution_targets: ['9:10'],
|
||||
extra_data: {
|
||||
auth_token_comfy_org: 'auth-token',
|
||||
api_key_comfy_org: 'api-key',
|
||||
comfy_usage_source: 'comfyui-frontend',
|
||||
preview_method: 'auto'
|
||||
}
|
||||
})
|
||||
expect(requestBody(fetchApi)).not.toHaveProperty('number')
|
||||
})
|
||||
|
||||
it('omits default-only queue options and sets explicit queue numbers', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockImplementation(() =>
|
||||
Promise.resolve(jsonResponse({ prompt_id: 'queued' }))
|
||||
)
|
||||
|
||||
await api.queuePrompt(0, promptData(), { previewMethod: 'default' })
|
||||
await api.queuePrompt(4, promptData())
|
||||
|
||||
expect(requestBody(fetchApi, 0)).toMatchObject({ client_id: '' })
|
||||
expect(requestBody(fetchApi, 0)).not.toHaveProperty('front')
|
||||
expect(requestBody(fetchApi, 0)).not.toHaveProperty('number')
|
||||
expect(
|
||||
requestBody(fetchApi, 0).extra_data as Record<string, unknown>
|
||||
).not.toHaveProperty('preview_method')
|
||||
expect(requestBody(fetchApi, 1)).toMatchObject({ number: 4 })
|
||||
})
|
||||
|
||||
it('throws parsed prompt errors from non-200 responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
jsonResponse(
|
||||
{
|
||||
error: {
|
||||
type: 'server_error',
|
||||
message: 'Server rejected prompt',
|
||||
details: 'bad output'
|
||||
}
|
||||
},
|
||||
{ status: 400, statusText: 'Bad Request' }
|
||||
)
|
||||
)
|
||||
|
||||
await expect(api.queuePrompt(0, promptData())).rejects.toThrow(
|
||||
'Prompt execution failed'
|
||||
)
|
||||
})
|
||||
|
||||
it('wraps non-json prompt errors with status details', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
new Response('backend exploded', {
|
||||
status: 500,
|
||||
statusText: 'Server Error'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(api.queuePrompt(0, promptData())).rejects.toMatchObject({
|
||||
status: 500,
|
||||
response: {
|
||||
error: {
|
||||
message: '500 Server Error',
|
||||
details: 'backend exploded'
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi read helpers', () => {
|
||||
it('returns localized templates, default templates, and empty non-json responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: [{ name: 'localized' }]
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'text/html' },
|
||||
data: '<html></html>'
|
||||
})
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
|
||||
{ name: 'localized' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
|
||||
expect(vi.mocked(axios.get).mock.calls[0][0]).toContain(
|
||||
'/templates/index.fr.json'
|
||||
)
|
||||
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
|
||||
'/templates/index.json'
|
||||
)
|
||||
})
|
||||
|
||||
it('falls back from missing localized templates to the default index', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.mocked(axios.get)
|
||||
.mockRejectedValueOnce(new Error('missing locale'))
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: [{ name: 'default' }]
|
||||
})
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('ja')).resolves.toEqual([
|
||||
{ name: 'default' }
|
||||
])
|
||||
expect(vi.mocked(axios.get).mock.calls[1][0]).toContain(
|
||||
'/templates/index.json'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty model lists for 404s and filters internal folders', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse([
|
||||
{ name: 'checkpoints' },
|
||||
{ name: 'configs' },
|
||||
{ name: 'custom_nodes' }
|
||||
])
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(jsonResponse(['model.safetensors']))
|
||||
|
||||
await expect(api.getModelFolders()).resolves.toEqual([])
|
||||
await expect(api.getModelFolders()).resolves.toEqual([
|
||||
{ name: 'checkpoints' }
|
||||
])
|
||||
await expect(api.getModels('checkpoints')).resolves.toEqual([])
|
||||
await expect(api.getModels('checkpoints')).resolves.toEqual([
|
||||
'model.safetensors'
|
||||
])
|
||||
expect(fetchApi).toHaveBeenCalledTimes(4)
|
||||
})
|
||||
|
||||
it('handles model metadata text responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(new Response(''))
|
||||
.mockResolvedValueOnce(new Response('{"format":"safetensors"}'))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('not json', { status: 200, statusText: 'OK' })
|
||||
)
|
||||
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toBe(null)
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toEqual({ format: 'safetensors' })
|
||||
await expect(
|
||||
api.viewMetadata('checkpoints', 'a.safetensors')
|
||||
).resolves.toBe(null)
|
||||
})
|
||||
|
||||
it('gets fuse options only from json responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'application/json' },
|
||||
data: { keys: ['name'] }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
headers: { 'content-type': 'text/plain' },
|
||||
data: 'nope'
|
||||
})
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
vi.mocked(axios.get).mockRejectedValueOnce(new Error('missing'))
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi queue and data helpers', () => {
|
||||
it('routes item collection requests to queue or history', async () => {
|
||||
const api = new ComfyApi()
|
||||
const queue = vi.spyOn(api, 'getQueue').mockResolvedValue({
|
||||
Running: [],
|
||||
Pending: []
|
||||
})
|
||||
const historyItem = fromPartial<JobListItem>({
|
||||
id: 'history-1',
|
||||
status: 'completed',
|
||||
create_time: 1,
|
||||
priority: 0
|
||||
})
|
||||
const history = vi.spyOn(api, 'getHistory').mockResolvedValue([historyItem])
|
||||
|
||||
await expect(api.getItems('queue')).resolves.toEqual({
|
||||
Running: [],
|
||||
Pending: []
|
||||
})
|
||||
await expect(api.getItems('history')).resolves.toEqual([historyItem])
|
||||
expect(queue).toHaveBeenCalledOnce()
|
||||
expect(history).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('returns queue fallbacks unless errors are requested', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
fetchJobs.fetchQueue.mockRejectedValue(new Error('network'))
|
||||
|
||||
await expect(api.getQueue()).resolves.toEqual({ Running: [], Pending: [] })
|
||||
await expect(api.getQueue({ throwOnError: true })).rejects.toThrow(
|
||||
'network'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty history when fetchHistory fails', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
fetchJobs.fetchHistory.mockRejectedValue(new Error('history down'))
|
||||
|
||||
await expect(api.getHistory()).resolves.toEqual([])
|
||||
})
|
||||
|
||||
it('posts item mutations with and without request bodies', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi.spyOn(api, 'fetchApi').mockResolvedValue(new Response())
|
||||
|
||||
await api.deleteItem('history', 'job-1')
|
||||
await api.clearItems('queue')
|
||||
await api.interrupt(null)
|
||||
await api.interrupt('running-1')
|
||||
|
||||
expect(fetchApi.mock.calls.map((call) => call[0])).toEqual([
|
||||
'/history',
|
||||
'/queue',
|
||||
'/interrupt',
|
||||
'/interrupt'
|
||||
])
|
||||
expect(fetchApi.mock.calls[0][1]?.body).toBe(
|
||||
JSON.stringify({ delete: ['job-1'] })
|
||||
)
|
||||
expect(fetchApi.mock.calls[1][1]?.body).toBe(
|
||||
JSON.stringify({ clear: true })
|
||||
)
|
||||
expect(fetchApi.mock.calls[2][1]?.body).toBeUndefined()
|
||||
expect(fetchApi.mock.calls[3][1]?.body).toBe(
|
||||
JSON.stringify({ prompt_id: 'running-1' })
|
||||
)
|
||||
})
|
||||
|
||||
it('throws unauthorized settings responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi').mockResolvedValue(
|
||||
new Response('', { status: 401, statusText: 'Unauthorized' })
|
||||
)
|
||||
|
||||
await expect(api.getSettings()).rejects.toThrow('Unauthorized')
|
||||
})
|
||||
|
||||
it('stores user data with default and raw-body options', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(new Response('', { status: 200 }))
|
||||
const raw = new Blob(['raw'])
|
||||
|
||||
await api.storeUserData('a/b.json', { ok: true })
|
||||
await api.storeUserData('raw.bin', raw, {
|
||||
overwrite: false,
|
||||
stringify: false,
|
||||
throwOnError: false,
|
||||
full_info: true
|
||||
})
|
||||
|
||||
expect(fetchApi.mock.calls[0][0]).toBe(
|
||||
'/userdata/a%2Fb.json?overwrite=true&full_info=false'
|
||||
)
|
||||
expect(fetchApi.mock.calls[0][1]?.body).toBe(JSON.stringify({ ok: true }))
|
||||
expect(fetchApi.mock.calls[1][0]).toBe(
|
||||
'/userdata/raw.bin?overwrite=false&full_info=true'
|
||||
)
|
||||
expect(fetchApi.mock.calls[1][1]?.body).toBe(raw)
|
||||
})
|
||||
|
||||
it('honors storeUserData throwOnError', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
|
||||
await expect(api.storeUserData('bad.json', {})).rejects.toThrow(
|
||||
"Error storing user data file 'bad.json': 500 Server Error"
|
||||
)
|
||||
await expect(
|
||||
api.storeUserData('bad.json', {}, { throwOnError: false })
|
||||
).resolves.toHaveProperty('status', 500)
|
||||
})
|
||||
|
||||
it('lists full user data info by status', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse([{ path: 'x' }]))
|
||||
|
||||
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([])
|
||||
await expect(api.listUserDataFullInfo('models/')).rejects.toThrow(
|
||||
"Error getting user data list 'models': 500 Server Error"
|
||||
)
|
||||
await expect(api.listUserDataFullInfo('models/')).resolves.toEqual([
|
||||
{ path: 'x' }
|
||||
])
|
||||
})
|
||||
|
||||
it('loads global subgraph records and deferred data', async () => {
|
||||
const fetchApi = vi
|
||||
.spyOn(sharedApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({
|
||||
ready: { name: 'Ready', info: { node_pack: 'core' }, data: '{}' },
|
||||
lazy: { name: 'Lazy', info: { node_pack: 'core' } }
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({ data: '{"lazy":true}' }))
|
||||
|
||||
await expect(sharedApi.getGlobalSubgraphs()).resolves.toEqual({})
|
||||
const subgraphs = await sharedApi.getGlobalSubgraphs()
|
||||
|
||||
expect(subgraphs.ready.data).toBe('{}')
|
||||
expect(subgraphs.lazy.data).toBeInstanceOf(Promise)
|
||||
await expect(subgraphs.lazy.data).resolves.toBe('{"lazy":true}')
|
||||
expect(fetchApi).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
829
src/scripts/api.test.ts
Normal file
829
src/scripts/api.test.ts
Normal file
@@ -0,0 +1,829 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import axios from 'axios'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
api as singletonApi,
|
||||
ComfyApi,
|
||||
PromptExecutionError,
|
||||
UnauthorizedError
|
||||
} from '@/scripts/api'
|
||||
import type { ComfyApiWorkflow } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { NodeExecutionId } from '@/types/nodeIdentification'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
|
||||
const { mockToastStore } = vi.hoisted(() => ({
|
||||
mockToastStore: {
|
||||
add: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/auth/unified/remintRetry', () => ({
|
||||
fetchWithUnifiedRemint: vi.fn(),
|
||||
shouldRemintCloudRequest: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => mockToastStore)
|
||||
}))
|
||||
|
||||
vi.mock('axios', () => ({
|
||||
default: {
|
||||
get: vi.fn(),
|
||||
patch: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
class FakeWebSocket extends EventTarget {
|
||||
static instances: FakeWebSocket[] = []
|
||||
|
||||
binaryType = ''
|
||||
sent: string[] = []
|
||||
|
||||
constructor(readonly url: string) {
|
||||
super()
|
||||
FakeWebSocket.instances.push(this)
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sent.push(data)
|
||||
}
|
||||
|
||||
close() {
|
||||
this.dispatchEvent(new Event('close'))
|
||||
}
|
||||
}
|
||||
|
||||
function jsonResponse(data: unknown, init: ResponseInit = {}) {
|
||||
return new Response(JSON.stringify(data), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
...init
|
||||
})
|
||||
}
|
||||
|
||||
function createWorkflow() {
|
||||
return {
|
||||
last_node_id: 0,
|
||||
last_link_id: 0,
|
||||
nodes: [],
|
||||
links: [],
|
||||
groups: [],
|
||||
config: {},
|
||||
extra: {},
|
||||
version: 0.4
|
||||
}
|
||||
}
|
||||
|
||||
function binaryMessage(type: number, payload: Uint8Array) {
|
||||
const bytes = new Uint8Array(4 + payload.length)
|
||||
new DataView(bytes.buffer).setUint32(0, type)
|
||||
bytes.set(payload, 4)
|
||||
return bytes.buffer
|
||||
}
|
||||
|
||||
function uint32(value: number) {
|
||||
const bytes = new Uint8Array(4)
|
||||
new DataView(bytes.buffer).setUint32(0, value)
|
||||
return bytes
|
||||
}
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
describe('PromptExecutionError', () => {
|
||||
it('formats string, structured, and node-level errors', () => {
|
||||
expect(
|
||||
new PromptExecutionError({
|
||||
error: 'Queue rejected',
|
||||
node_errors: {}
|
||||
}).toString()
|
||||
).toBe('Queue rejected')
|
||||
|
||||
expect(
|
||||
new PromptExecutionError({
|
||||
error: {
|
||||
type: 'invalid_prompt',
|
||||
message: 'Invalid prompt',
|
||||
details: 'missing input'
|
||||
},
|
||||
node_errors: {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
dependent_outputs: ['1'],
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Required input',
|
||||
details: 'source'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}).toString()
|
||||
).toContain('Invalid prompt: missing input\nPreviewAny:')
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
FakeWebSocket.instances = []
|
||||
window.name = ''
|
||||
sessionStorage.clear()
|
||||
vi.stubGlobal('WebSocket', FakeWebSocket as unknown as typeof WebSocket)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('builds API, internal, file, and fetch URLs with user headers', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.user = 'reviewer'
|
||||
vi.mocked(fetchWithUnifiedRemint).mockResolvedValue(
|
||||
jsonResponse({ ok: true })
|
||||
)
|
||||
|
||||
await api.fetchApi('/queue', {
|
||||
headers: new Headers([['X-Test', '1']])
|
||||
})
|
||||
|
||||
expect(api.apiURL('/api/custom')).toBe(`${api.api_base}/api/custom`)
|
||||
expect(api.apiURL('/queue')).toBe(`${api.api_base}/api/queue`)
|
||||
expect(api.internalURL('/logs')).toBe(`${api.api_base}/internal/logs`)
|
||||
expect(api.fileURL('/view')).toBe(`${api.api_base}/view`)
|
||||
const [, options] = vi.mocked(fetchWithUnifiedRemint).mock.calls[0]
|
||||
expect(options.headers).toBeInstanceOf(Headers)
|
||||
expect((options.headers as Headers).get('Comfy-User')).toBe('reviewer')
|
||||
expect((options.headers as Headers).get('X-Test')).toBe('1')
|
||||
})
|
||||
|
||||
it('guards event listeners and still allows removing them', async () => {
|
||||
const api = new ComfyApi()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
const listener = vi.fn()
|
||||
const throwingListener = vi.fn(() => {
|
||||
throw new Error('listener failed')
|
||||
})
|
||||
const asyncListener = vi.fn(() => Promise.reject(new Error('async failed')))
|
||||
const objectListener = { handleEvent: vi.fn() }
|
||||
|
||||
api.addEventListener('status', null)
|
||||
api.removeEventListener('status', null)
|
||||
api.addEventListener('status', listener)
|
||||
api.addEventListener('status', throwingListener)
|
||||
api.addEventListener('status', asyncListener)
|
||||
api.addEventListener('status', fromAny(objectListener))
|
||||
|
||||
api.dispatchCustomEvent('status', { exec_info: { queue_remaining: 1 } })
|
||||
await Promise.resolve()
|
||||
|
||||
expect(listener).toHaveBeenCalled()
|
||||
expect(throwingListener).toHaveBeenCalled()
|
||||
expect(asyncListener).toHaveBeenCalled()
|
||||
expect(objectListener.handleEvent).toHaveBeenCalled()
|
||||
expect(warn).toHaveBeenCalledTimes(2)
|
||||
|
||||
api.removeEventListener('status', listener)
|
||||
api.dispatchCustomEvent('status', null)
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('reuses guarded listener wrappers and ignores unknown removals', () => {
|
||||
const api = new ComfyApi()
|
||||
const listener = vi.fn()
|
||||
const neverRegistered = vi.fn()
|
||||
|
||||
api.addEventListener('status', listener)
|
||||
api.addEventListener('status', listener)
|
||||
api.removeEventListener('status', neverRegistered)
|
||||
api.dispatchCustomEvent('status', null)
|
||||
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
expect(neverRegistered).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('supports guarded custom event listeners', () => {
|
||||
const api = new ComfyApi()
|
||||
const listener = vi.fn()
|
||||
|
||||
api.addCustomEventListener('custom-node-event', listener)
|
||||
;(api as EventTarget).dispatchEvent(
|
||||
new CustomEvent('custom-node-event', { detail: { ok: true } })
|
||||
)
|
||||
api.removeCustomEventListener('custom-node-event', listener)
|
||||
;(api as EventTarget).dispatchEvent(
|
||||
new CustomEvent('custom-node-event', { detail: { ok: false } })
|
||||
)
|
||||
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('routes websocket JSON messages and custom registered messages', () => {
|
||||
window.name = 'existing-client'
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const executing = vi.fn()
|
||||
const featureFlags = vi.fn()
|
||||
const custom = vi.fn()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
vi.spyOn(console, 'log').mockImplementation(() => undefined)
|
||||
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('executing', executing)
|
||||
api.addEventListener('feature_flags', featureFlags)
|
||||
api.addCustomEventListener('custom-message', custom)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
socket.dispatchEvent(new Event('open'))
|
||||
|
||||
expect(socket.url).toContain('clientId=existing-client')
|
||||
expect(JSON.parse(socket.sent[0])).toMatchObject({
|
||||
type: 'feature_flags'
|
||||
})
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'status',
|
||||
data: {
|
||||
sid: 'fresh-client',
|
||||
status: { exec_info: { queue_remaining: 2 } }
|
||||
}
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'executing',
|
||||
data: { node: '12' }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'feature_flags',
|
||||
data: { supports_progress_text_metadata: true }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'custom-message',
|
||||
data: { from: 'extension' }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'unknown-message',
|
||||
data: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'unknown-message',
|
||||
data: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(api.clientId).toBe('fresh-client')
|
||||
expect(window.name).toBe('fresh-client')
|
||||
expect(sessionStorage.getItem('clientId')).toBe('fresh-client')
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { exec_info: { queue_remaining: 2 } }
|
||||
})
|
||||
)
|
||||
expect(executing).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: '12' })
|
||||
)
|
||||
expect(featureFlags).toHaveBeenCalled()
|
||||
expect(api.serverSupportsFeature('supports_progress_text_metadata')).toBe(
|
||||
true
|
||||
)
|
||||
expect(custom).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: { from: 'extension' } })
|
||||
)
|
||||
expect(api.reportedUnknownMessageTypes.has('unknown-message')).toBe(true)
|
||||
expect(warn).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('polls status when the initial websocket connection fails', async () => {
|
||||
vi.useFakeTimers()
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({ exec_info: { queue_remaining: 4 } })
|
||||
)
|
||||
.mockRejectedValueOnce(new Error('poll failed'))
|
||||
api.addEventListener('status', status)
|
||||
|
||||
api.init()
|
||||
FakeWebSocket.instances[0].dispatchEvent(new Event('error'))
|
||||
await vi.advanceTimersByTimeAsync(1000)
|
||||
await vi.advanceTimersByTimeAsync(1000)
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { exec_info: { queue_remaining: 4 } }
|
||||
})
|
||||
)
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('emits reconnect lifecycle events after an opened websocket closes', async () => {
|
||||
vi.useFakeTimers()
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const reconnecting = vi.fn()
|
||||
const reconnected = vi.fn()
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('reconnecting', reconnecting)
|
||||
api.addEventListener('reconnected', reconnected)
|
||||
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
socket.dispatchEvent(new Event('open'))
|
||||
socket.close()
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
expect(reconnecting).toHaveBeenCalledOnce()
|
||||
|
||||
await vi.advanceTimersByTimeAsync(300)
|
||||
const reconnectSocket = FakeWebSocket.instances[1]
|
||||
reconnectSocket.dispatchEvent(new Event('open'))
|
||||
|
||||
expect(reconnected).toHaveBeenCalledOnce()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('routes websocket variants without session ids and display-node fallbacks', () => {
|
||||
const api = new ComfyApi()
|
||||
const status = vi.fn()
|
||||
const executing = vi.fn()
|
||||
api.addEventListener('status', status)
|
||||
api.addEventListener('executing', executing)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'status',
|
||||
data: { status: undefined }
|
||||
})
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: JSON.stringify({
|
||||
type: 'executing',
|
||||
data: { node: 'real', display_node: 'display' }
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
expect(status).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: null })
|
||||
)
|
||||
expect(executing).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: 'display' })
|
||||
)
|
||||
})
|
||||
|
||||
it('routes binary preview and progress websocket messages', () => {
|
||||
const api = new ComfyApi()
|
||||
const preview = vi.fn()
|
||||
const previewWithMetadata = vi.fn()
|
||||
const progressText = vi.fn()
|
||||
const encoder = new TextEncoder()
|
||||
api.serverFeatureFlags.value = {
|
||||
supports_progress_text_metadata: true
|
||||
}
|
||||
api.addEventListener('b_preview', preview)
|
||||
api.addEventListener('b_preview_with_metadata', previewWithMetadata)
|
||||
api.addEventListener('progress_text', progressText)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(2), new Uint8Array([1, 2])))
|
||||
})
|
||||
)
|
||||
|
||||
const promptId = encoder.encode('prompt-1')
|
||||
const nodeId = encoder.encode('7')
|
||||
const progressPayload = concatBytes(
|
||||
uint32(promptId.length),
|
||||
promptId,
|
||||
uint32(nodeId.length),
|
||||
nodeId,
|
||||
encoder.encode('loading')
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(3, progressPayload)
|
||||
})
|
||||
)
|
||||
|
||||
const metadata = encoder.encode(
|
||||
JSON.stringify({
|
||||
image_type: 'image/webp',
|
||||
node_id: '7',
|
||||
display_node_id: '7',
|
||||
parent_node_id: '4',
|
||||
real_node_id: '7',
|
||||
prompt_id: 'prompt-1'
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(
|
||||
4,
|
||||
concatBytes(uint32(metadata.length), metadata, new Uint8Array([9]))
|
||||
)
|
||||
})
|
||||
)
|
||||
|
||||
expect(preview).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({ type: 'image/png' })
|
||||
})
|
||||
)
|
||||
expect(progressText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: {
|
||||
nodeId: '7',
|
||||
text: 'loading',
|
||||
prompt_id: 'prompt-1'
|
||||
}
|
||||
})
|
||||
)
|
||||
expect(previewWithMetadata).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({
|
||||
nodeId: '7',
|
||||
parentNodeId: '4',
|
||||
jobId: 'prompt-1'
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('routes binary jpeg/default previews and malformed binary messages defensively', () => {
|
||||
const api = new ComfyApi()
|
||||
const preview = vi.fn()
|
||||
const progressText = vi.fn()
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
const encoder = new TextEncoder()
|
||||
api.addEventListener('b_preview', preview)
|
||||
api.addEventListener('progress_text', progressText)
|
||||
api.init()
|
||||
const socket = FakeWebSocket.instances[0]
|
||||
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(1), new Uint8Array([1])))
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(1, concatBytes(uint32(99), new Uint8Array([2])))
|
||||
})
|
||||
)
|
||||
|
||||
const nodeId = encoder.encode('node')
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(
|
||||
3,
|
||||
concatBytes(uint32(nodeId.length), nodeId, encoder.encode('ready'))
|
||||
)
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(3, new Uint8Array([1]))
|
||||
})
|
||||
)
|
||||
socket.dispatchEvent(
|
||||
new MessageEvent('message', {
|
||||
data: binaryMessage(99, new Uint8Array())
|
||||
})
|
||||
)
|
||||
|
||||
expect(preview).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: expect.objectContaining({ type: 'image/jpeg' })
|
||||
})
|
||||
)
|
||||
expect(progressText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
detail: { nodeId: 'node', text: 'ready' }
|
||||
})
|
||||
)
|
||||
expect(warn).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('serializes prompt queue options and surfaces non-200 errors', async () => {
|
||||
const api = new ComfyApi()
|
||||
api.clientId = 'client-1'
|
||||
api.authToken = 'token-1'
|
||||
api.apiKey = 'key-1'
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ prompt_id: 'queued' }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('backend exploded', {
|
||||
status: 500,
|
||||
statusText: 'Server Error'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
api.queuePrompt(
|
||||
-1,
|
||||
{ output: prompt, workflow: createWorkflow() },
|
||||
{
|
||||
partialExecutionTargets: ['7' as NodeExecutionId],
|
||||
previewMethod: 'latent2rgb'
|
||||
}
|
||||
)
|
||||
).resolves.toEqual({ prompt_id: 'queued' })
|
||||
|
||||
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
|
||||
expect(body).toMatchObject({
|
||||
client_id: 'client-1',
|
||||
prompt,
|
||||
partial_execution_targets: ['7'],
|
||||
front: true,
|
||||
extra_data: {
|
||||
auth_token_comfy_org: 'token-1',
|
||||
api_key_comfy_org: 'key-1',
|
||||
comfy_usage_source: 'comfyui-frontend',
|
||||
preview_method: 'latent2rgb'
|
||||
}
|
||||
})
|
||||
expect(body.number).toBeUndefined()
|
||||
|
||||
await expect(
|
||||
api.queuePrompt(3, { output: prompt, workflow: createWorkflow() })
|
||||
).rejects.toMatchObject({ status: 500 })
|
||||
})
|
||||
|
||||
it('omits queue position and default preview method for normal queueing', async () => {
|
||||
const api = new ComfyApi()
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValue(jsonResponse({ prompt_id: 'queued' }))
|
||||
|
||||
await api.queuePrompt(
|
||||
0,
|
||||
{ output: prompt, workflow: createWorkflow() },
|
||||
{
|
||||
previewMethod: 'default'
|
||||
}
|
||||
)
|
||||
|
||||
const body = JSON.parse(fetchApi.mock.calls[0][1]?.body as string)
|
||||
expect(body.front).toBeUndefined()
|
||||
expect(body.number).toBeUndefined()
|
||||
expect(body.extra_data.preview_method).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles shareable assets, settings, userdata, subgraphs, and memory APIs', async () => {
|
||||
const api = new ComfyApi()
|
||||
const prompt: ComfyApiWorkflow = {
|
||||
1: {
|
||||
class_type: 'PreviewAny',
|
||||
inputs: {},
|
||||
_meta: { title: 'PreviewAny' }
|
||||
}
|
||||
}
|
||||
vi.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ assets: [] }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
new Response('', { status: 401, statusText: 'Unauthorized' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 204 }))
|
||||
.mockResolvedValueOnce(jsonResponse([], { status: 404 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({}, { status: 500, statusText: 'Server Error' })
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
vi.spyOn(singletonApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ data: 'subgraph-data' }))
|
||||
.mockResolvedValueOnce(jsonResponse({ missing: {} }))
|
||||
.mockResolvedValueOnce(jsonResponse({ one: { data: 'inline' } }))
|
||||
|
||||
await expect(
|
||||
api.getShareableAssets(prompt, { owned: false })
|
||||
).resolves.toEqual({ assets: [] })
|
||||
await expect(api.getShareableAssets(prompt)).rejects.toThrow(
|
||||
'Failed to fetch shareable assets'
|
||||
)
|
||||
await expect(api.getSettings()).rejects.toBeInstanceOf(UnauthorizedError)
|
||||
await expect(
|
||||
api.storeUserData('plain.txt', 'raw', {
|
||||
overwrite: false,
|
||||
stringify: false,
|
||||
throwOnError: false,
|
||||
full_info: true
|
||||
})
|
||||
).resolves.toHaveProperty('status', 204)
|
||||
await expect(api.listUserDataFullInfo('/missing/')).resolves.toEqual([])
|
||||
await expect(api.listUserDataFullInfo('/broken/')).rejects.toThrow(
|
||||
"Error getting user data list '/broken'"
|
||||
)
|
||||
await expect(api.getGlobalSubgraphData('one')).resolves.toBe(
|
||||
'subgraph-data'
|
||||
)
|
||||
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
|
||||
"Global subgraph 'missing' returned empty data"
|
||||
)
|
||||
await expect(api.getGlobalSubgraphs()).resolves.toEqual({
|
||||
one: { data: 'inline' }
|
||||
})
|
||||
await api.freeMemory({ freeExecutionCache: true })
|
||||
await api.freeMemory({ freeExecutionCache: false })
|
||||
await api.freeMemory({ freeExecutionCache: true })
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
summary: 'Models and Execution Cache have been cleared.'
|
||||
})
|
||||
)
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ summary: 'Models have been unloaded.' })
|
||||
)
|
||||
expect(mockToastStore.add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
summary:
|
||||
'Unloading of models failed. Installed ComfyUI may be an outdated version.'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects non-success global subgraph data responses', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(singletonApi, 'fetchApi').mockResolvedValueOnce(
|
||||
jsonResponse({}, { status: 404, statusText: 'Not Found' })
|
||||
)
|
||||
|
||||
await expect(api.getGlobalSubgraphData('missing')).rejects.toThrow(
|
||||
"Failed to fetch global subgraph 'missing': 404 Not Found"
|
||||
)
|
||||
})
|
||||
|
||||
it('handles successful settings and userdata helper request shapes', async () => {
|
||||
const api = new ComfyApi()
|
||||
const fetchApi = vi
|
||||
.spyOn(api, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({ theme: 'dark' }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 200 }))
|
||||
|
||||
await expect(api.getSettings()).resolves.toEqual({ theme: 'dark' })
|
||||
await expect(api.storeUserData('bad.json', { a: 1 })).rejects.toThrow(
|
||||
"Error storing user data file 'bad.json'"
|
||||
)
|
||||
await api.moveUserData('old/path.json', 'new path.json')
|
||||
await api.deleteUserData('old/path.json')
|
||||
|
||||
expect(fetchApi.mock.calls[2]).toEqual([
|
||||
'/userdata/old%2Fpath.json/move/new%20path.json?overwrite=false',
|
||||
{ method: 'POST' }
|
||||
])
|
||||
expect(fetchApi.mock.calls[3]).toEqual([
|
||||
'/userdata/old%2Fpath.json',
|
||||
{ method: 'DELETE' }
|
||||
])
|
||||
})
|
||||
|
||||
it('handles global subgraph fallbacks and log endpoints', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.spyOn(singletonApi, 'fetchApi')
|
||||
.mockResolvedValueOnce(jsonResponse({}, { status: 500 }))
|
||||
.mockResolvedValueOnce(
|
||||
jsonResponse({
|
||||
missing: {
|
||||
name: 'Missing data',
|
||||
info: { node_pack: 'core' }
|
||||
}
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(jsonResponse({ data: 'lazy-data' }))
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({ data: 'log text', headers: {} })
|
||||
.mockResolvedValueOnce({ data: { logs: [] }, headers: {} })
|
||||
.mockRejectedValueOnce(new Error('no folders'))
|
||||
.mockResolvedValueOnce({
|
||||
data: { checkpoints: ['/models'] },
|
||||
headers: {}
|
||||
})
|
||||
vi.mocked(axios.patch).mockResolvedValue({ data: undefined })
|
||||
|
||||
await expect(api.getGlobalSubgraphs()).resolves.toEqual({})
|
||||
const subgraphs = await api.getGlobalSubgraphs()
|
||||
await expect(subgraphs.missing.data).resolves.toBe('lazy-data')
|
||||
await expect(api.getLogs()).resolves.toBe('log text')
|
||||
await expect(api.getRawLogs()).resolves.toEqual({ logs: [] })
|
||||
await api.subscribeLogs(true)
|
||||
await expect(api.getFolderPaths()).resolves.toEqual({})
|
||||
await expect(api.getFolderPaths()).resolves.toEqual({
|
||||
checkpoints: ['/models']
|
||||
})
|
||||
|
||||
expect(axios.patch).toHaveBeenCalledWith(
|
||||
api.internalURL('/logs/subscribe'),
|
||||
{ enabled: true, clientId: undefined }
|
||||
)
|
||||
})
|
||||
|
||||
it('loads localized template indexes and fuse options defensively', async () => {
|
||||
const api = new ComfyApi()
|
||||
vi.mocked(axios.get)
|
||||
.mockResolvedValueOnce({
|
||||
data: [{ name: 'template' }],
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: '<html></html>',
|
||||
headers: { 'content-type': 'text/html' }
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('missing locale'))
|
||||
.mockResolvedValueOnce({
|
||||
data: [{ name: 'fallback' }],
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('default missing'))
|
||||
.mockResolvedValueOnce({
|
||||
data: { keys: ['name'] },
|
||||
headers: { 'content-type': 'application/json' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: '<html></html>',
|
||||
headers: { 'content-type': 'text/html' }
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
data: { ignored: true },
|
||||
headers: {}
|
||||
})
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => undefined)
|
||||
|
||||
await expect(api.getCoreWorkflowTemplates('fr')).resolves.toEqual([
|
||||
{ name: 'template' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates()).resolves.toEqual([])
|
||||
await expect(api.getCoreWorkflowTemplates('zh')).resolves.toEqual([
|
||||
{ name: 'fallback' }
|
||||
])
|
||||
await expect(api.getCoreWorkflowTemplates('en')).resolves.toEqual([])
|
||||
await expect(api.getFuseOptions()).resolves.toEqual({ keys: ['name'] })
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
await expect(api.getFuseOptions()).resolves.toBeNull()
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,12 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type {
|
||||
CanvasPointerEvent,
|
||||
Subgraph
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraphCanvas, LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ExecutedWsMessage } from '@/schemas/apiSchema'
|
||||
|
||||
const mockAssert = vi.hoisted(() => vi.fn())
|
||||
|
||||
@@ -14,7 +20,7 @@ const mockNodeOutputStore = vi.hoisted(() => ({
|
||||
}))
|
||||
|
||||
const mockSubgraphNavigationStore = vi.hoisted(() => ({
|
||||
exportState: vi.fn(() => []),
|
||||
exportState: vi.fn((): string[] => []),
|
||||
restoreState: vi.fn()
|
||||
}))
|
||||
|
||||
@@ -23,10 +29,24 @@ const mockWorkflowStore = vi.hoisted(() => ({
|
||||
getWorkflowByPath: vi.fn()
|
||||
}))
|
||||
|
||||
const mockExecutionStore = vi.hoisted(() => ({
|
||||
queuedJobs: {} as Record<string, { workflow: { changeTracker: unknown } }>
|
||||
}))
|
||||
|
||||
const mockMaskEditorIsOpened = vi.hoisted(() => vi.fn(() => false))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
constructor: {
|
||||
maskeditor_is_opended: mockMaskEditorIsOpened
|
||||
},
|
||||
graph: {},
|
||||
ui: {
|
||||
autoQueueEnabled: false,
|
||||
autoQueueMode: 'instant'
|
||||
},
|
||||
rootGraph: {
|
||||
subgraphs: new Map(),
|
||||
serialize: vi.fn(() => ({
|
||||
nodes: [],
|
||||
links: [],
|
||||
@@ -39,8 +59,10 @@ vi.mock('@/scripts/app', () => ({
|
||||
}))
|
||||
},
|
||||
canvas: {
|
||||
ds: { scale: 1, offset: [0, 0] }
|
||||
}
|
||||
ds: { scale: 1, offset: [0, 0] },
|
||||
setGraph: vi.fn()
|
||||
},
|
||||
loadGraphData: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -65,6 +87,10 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: vi.fn(() => mockWorkflowStore)
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/executionStore', () => ({
|
||||
useExecutionStore: vi.fn(() => mockExecutionStore)
|
||||
}))
|
||||
|
||||
import { app } from '@/scripts/app'
|
||||
import { api } from '@/scripts/api'
|
||||
import { ChangeTracker } from '@/scripts/changeTracker'
|
||||
@@ -107,13 +133,120 @@ function mockCanvasState(state: ComfyWorkflowJSON) {
|
||||
vi.mocked(app.rootGraph.serialize).mockReturnValue(state as never)
|
||||
}
|
||||
|
||||
type ListenerMap = Record<string, EventListener[]>
|
||||
|
||||
function storeListener(
|
||||
listeners: ListenerMap,
|
||||
type: string,
|
||||
listener: EventListenerOrEventListenerObject
|
||||
) {
|
||||
if (typeof listener === 'function') {
|
||||
listeners[type] ??= []
|
||||
listeners[type].push(listener)
|
||||
}
|
||||
}
|
||||
|
||||
function dispatchStored(listeners: ListenerMap, type: string, event: Event) {
|
||||
for (const listener of listeners[type] ?? []) {
|
||||
listener(event)
|
||||
}
|
||||
}
|
||||
|
||||
async function flushAsyncFrame() {
|
||||
await Promise.resolve()
|
||||
await Promise.resolve()
|
||||
}
|
||||
|
||||
function getApiListener(name: string) {
|
||||
const call = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find(([eventName]) => eventName === name)
|
||||
expect(call).toBeDefined()
|
||||
return call?.[1] as (event: CustomEvent<ExecutedWsMessage>) => void
|
||||
}
|
||||
|
||||
describe('ChangeTracker', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
nodeIdCounter = 0
|
||||
ChangeTracker.isLoadingGraph = false
|
||||
Reflect.set(ChangeTracker, '_checkStateWarned', false)
|
||||
mockWorkflowStore.activeWorkflow = null
|
||||
mockWorkflowStore.getWorkflowByPath.mockReturnValue(null)
|
||||
mockExecutionStore.queuedJobs = {}
|
||||
mockMaskEditorIsOpened.mockReturnValue(false)
|
||||
app.ui.autoQueueEnabled = false
|
||||
app.ui.autoQueueMode = 'instant'
|
||||
vi.mocked(app.canvas.setGraph).mockClear()
|
||||
vi.mocked(app.loadGraphData).mockResolvedValue(undefined)
|
||||
app.rootGraph.subgraphs.clear()
|
||||
app.canvas.ds.scale = 1
|
||||
app.canvas.ds.offset = [0, 0]
|
||||
})
|
||||
|
||||
describe('reset', () => {
|
||||
it('updates initialState from activeState or an explicit state', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const changed = createState(2)
|
||||
|
||||
tracker.activeState = changed
|
||||
tracker.reset()
|
||||
|
||||
expect(tracker.initialState).toEqual(changed)
|
||||
expect(tracker.initialState).not.toBe(changed)
|
||||
|
||||
const explicit = createState(3)
|
||||
tracker.reset(explicit)
|
||||
|
||||
expect(tracker.activeState).toEqual(explicit)
|
||||
expect(tracker.activeState).not.toBe(explicit)
|
||||
expect(tracker.initialState).toEqual(explicit)
|
||||
})
|
||||
|
||||
it('does not reset while restoring state', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const original = tracker.initialState
|
||||
tracker._restoringState = true
|
||||
|
||||
tracker.reset(createState(2))
|
||||
|
||||
expect(tracker.initialState).toBe(original)
|
||||
})
|
||||
})
|
||||
|
||||
describe('restore', () => {
|
||||
it('restores viewport, outputs, and root graph navigation', () => {
|
||||
const tracker = createTracker()
|
||||
app.canvas.ds.scale = 2
|
||||
app.canvas.ds.offset = [10, 20]
|
||||
mockNodeOutputStore.snapshotOutputs.mockReturnValue({ 1: { images: [] } })
|
||||
mockSubgraphNavigationStore.exportState.mockReturnValue([])
|
||||
|
||||
tracker.store()
|
||||
app.canvas.ds.scale = 1
|
||||
app.canvas.ds.offset = [0, 0]
|
||||
tracker.restore()
|
||||
|
||||
expect(app.canvas.ds.scale).toBe(2)
|
||||
expect(app.canvas.ds.offset).toEqual([10, 20])
|
||||
expect(mockNodeOutputStore.restoreOutputs).toHaveBeenCalledWith({
|
||||
1: { images: [] }
|
||||
})
|
||||
expect(mockSubgraphNavigationStore.restoreState).toHaveBeenCalledWith([])
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
|
||||
})
|
||||
|
||||
it('restores saved subgraph navigation when the subgraph exists', () => {
|
||||
const tracker = createTracker()
|
||||
const subgraph = { id: 'subgraph-1' } as unknown as Subgraph
|
||||
app.rootGraph.subgraphs.set('subgraph-1', subgraph)
|
||||
mockSubgraphNavigationStore.exportState.mockReturnValue(['subgraph-1'])
|
||||
|
||||
tracker.store()
|
||||
tracker.restore()
|
||||
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(subgraph)
|
||||
})
|
||||
})
|
||||
|
||||
describe('captureCanvasState', () => {
|
||||
@@ -169,9 +302,32 @@ describe('ChangeTracker', () => {
|
||||
expect.stringContaining('captureCanvasState')
|
||||
)
|
||||
})
|
||||
|
||||
it('reports inactive tracker calls only once for the same workflow', () => {
|
||||
const tracker = createTracker()
|
||||
tracker.workflow.path = '/test/dedupe-workflow.json'
|
||||
mockWorkflowStore.activeWorkflow = { changeTracker: {} }
|
||||
|
||||
tracker.captureCanvasState()
|
||||
tracker.captureCanvasState()
|
||||
|
||||
expect(mockAssert).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
|
||||
describe('state capture', () => {
|
||||
it('sets the active state without pushing undo when none exists yet', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const changed = createState(2)
|
||||
tracker.activeState = undefined as never
|
||||
mockCanvasState(changed)
|
||||
|
||||
tracker.captureCanvasState()
|
||||
|
||||
expect(tracker.activeState).toEqual(changed)
|
||||
expect(tracker.undoQueue).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('pushes to undoQueue, updates activeState, and calls updateModified', () => {
|
||||
const initial = createState(1)
|
||||
const tracker = createTracker(initial)
|
||||
@@ -238,6 +394,19 @@ describe('ChangeTracker', () => {
|
||||
|
||||
expect(tracker.undoQueue).toHaveLength(ChangeTracker.MAX_HISTORY)
|
||||
})
|
||||
|
||||
it('does not capture until the outer change transaction finishes', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
tracker.beforeChange()
|
||||
tracker.beforeChange()
|
||||
mockCanvasState(createState(2))
|
||||
|
||||
tracker.afterChange()
|
||||
expect(app.rootGraph.serialize).not.toHaveBeenCalled()
|
||||
|
||||
tracker.afterChange()
|
||||
expect(app.rootGraph.serialize).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -302,6 +471,105 @@ describe('ChangeTracker', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('updateModified', () => {
|
||||
it('updates workflow modified state when the store can find it', () => {
|
||||
const state = createState(1)
|
||||
const tracker = createTracker(state)
|
||||
const workflow = { isModified: true }
|
||||
mockWorkflowStore.getWorkflowByPath.mockReturnValue(workflow)
|
||||
|
||||
tracker.updateModified()
|
||||
expect(workflow.isModified).toBe(false)
|
||||
|
||||
tracker.activeState = createState(2)
|
||||
tracker.updateModified()
|
||||
expect(workflow.isModified).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('undo and redo', () => {
|
||||
it('restores previous state and moves the current state to the target queue', async () => {
|
||||
const initial = createState(1)
|
||||
const changed = createState(2)
|
||||
const tracker = createTracker(changed)
|
||||
tracker.undoQueue.push(initial)
|
||||
|
||||
await tracker.undo()
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith(
|
||||
initial,
|
||||
false,
|
||||
false,
|
||||
tracker.workflow,
|
||||
{
|
||||
checkForRerouteMigration: false,
|
||||
silentAssetErrors: true
|
||||
}
|
||||
)
|
||||
expect(tracker.activeState).toBe(initial)
|
||||
expect(tracker.redoQueue).toEqual([changed])
|
||||
expect(tracker._restoringState).toBe(false)
|
||||
})
|
||||
|
||||
it('clears restoring state when loading fails', async () => {
|
||||
const tracker = createTracker(createState(2))
|
||||
tracker.undoQueue.push(createState(1))
|
||||
vi.mocked(app.loadGraphData).mockRejectedValueOnce(
|
||||
new Error('load failed')
|
||||
)
|
||||
|
||||
await expect(tracker.undo()).rejects.toThrow('load failed')
|
||||
|
||||
expect(tracker._restoringState).toBe(false)
|
||||
})
|
||||
|
||||
it('does nothing when no previous state exists', async () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
|
||||
await tracker.undo()
|
||||
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('handles keyboard undo and redo shortcuts', async () => {
|
||||
const tracker = createTracker()
|
||||
const undo = vi.spyOn(tracker, 'undo').mockResolvedValue()
|
||||
const redo = vi.spyOn(tracker, 'redo').mockResolvedValue()
|
||||
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', {
|
||||
key: 'z',
|
||||
ctrlKey: true,
|
||||
shiftKey: true
|
||||
})
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', { key: 'y', metaKey: true })
|
||||
)
|
||||
).resolves.toBe(true)
|
||||
await expect(
|
||||
tracker.undoRedo(
|
||||
new KeyboardEvent('keydown', {
|
||||
key: 'z',
|
||||
ctrlKey: true,
|
||||
altKey: true
|
||||
})
|
||||
)
|
||||
).resolves.toBeUndefined()
|
||||
|
||||
expect(undo).toHaveBeenCalledOnce()
|
||||
expect(redo).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('checkState (deprecated)', () => {
|
||||
it('delegates to captureCanvasState', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
@@ -312,5 +580,389 @@ describe('ChangeTracker', () => {
|
||||
|
||||
expect(tracker.activeState).toEqual(changed)
|
||||
})
|
||||
|
||||
it('warns only once before delegating', () => {
|
||||
const tracker = createTracker(createState(1))
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => undefined)
|
||||
|
||||
tracker.checkState()
|
||||
tracker.checkState()
|
||||
|
||||
expect(warn).toHaveBeenCalledOnce()
|
||||
warn.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('bindInput', () => {
|
||||
it('returns false for missing canvas or body elements', () => {
|
||||
expect(ChangeTracker.bindInput(null)).toBe(false)
|
||||
expect(ChangeTracker.bindInput(document.createElement('canvas'))).toBe(
|
||||
false
|
||||
)
|
||||
expect(ChangeTracker.bindInput(document.body)).toBe(false)
|
||||
})
|
||||
|
||||
it('captures state once when an input-like element changes', () => {
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
const input = document.createElement('input')
|
||||
|
||||
expect(ChangeTracker.bindInput(input)).toBe(true)
|
||||
input.dispatchEvent(new Event('change'))
|
||||
input.dispatchEvent(new Event('change'))
|
||||
|
||||
expect(capture).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('binds textarea-like elements that expose an input handler slot', () => {
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
const element = document.createElement('div') as HTMLElement & {
|
||||
oninput: unknown
|
||||
}
|
||||
element.oninput = null
|
||||
|
||||
expect(ChangeTracker.bindInput(element)).toBe(true)
|
||||
element.dispatchEvent(new Event('change'))
|
||||
|
||||
expect(capture).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
|
||||
describe('init', () => {
|
||||
it('captures changes from registered browser, graph, and API events', async () => {
|
||||
const windowListeners: ListenerMap = {}
|
||||
const documentListeners: ListenerMap = {}
|
||||
const windowAddSpy = vi
|
||||
.spyOn(window, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(windowListeners, type, listener)
|
||||
})
|
||||
const documentAddSpy = vi
|
||||
.spyOn(document, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(documentListeners, type, listener)
|
||||
})
|
||||
const rafSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
|
||||
const processMouseUp = vi.fn(() => true)
|
||||
const prompt = vi.fn()
|
||||
const close = vi.fn(() => true)
|
||||
const originalProcessMouseUp = LGraphCanvas.prototype.processMouseUp
|
||||
const originalPrompt = LGraphCanvas.prototype.prompt
|
||||
const originalClose = LiteGraph.ContextMenu.prototype.close
|
||||
|
||||
LGraphCanvas.prototype.processMouseUp = processMouseUp
|
||||
LGraphCanvas.prototype.prompt = prompt
|
||||
LiteGraph.ContextMenu.prototype.close = close
|
||||
|
||||
try {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
|
||||
dispatchStored(windowListeners, 'mouseup', new MouseEvent('mouseup'))
|
||||
getApiListener('promptQueued')(
|
||||
new CustomEvent('promptQueued', {
|
||||
detail: {} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
getApiListener('graphCleared')(
|
||||
new CustomEvent('graphCleared', {
|
||||
detail: {} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
dispatchStored(
|
||||
documentListeners,
|
||||
'litegraph:canvas',
|
||||
new CustomEvent('litegraph:canvas', {
|
||||
detail: { subType: 'before-change' }
|
||||
})
|
||||
)
|
||||
dispatchStored(
|
||||
documentListeners,
|
||||
'litegraph:canvas',
|
||||
new CustomEvent('litegraph:canvas', {
|
||||
detail: { subType: 'after-change' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(4)
|
||||
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'Control' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
dispatchStored(windowListeners, 'keyup', new KeyboardEvent('keyup'))
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(5)
|
||||
|
||||
const undoRedo = vi.spyOn(tracker, 'undoRedo').mockResolvedValue(true)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'z', ctrlKey: true })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(undoRedo).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(5)
|
||||
|
||||
undoRedo.mockResolvedValue(undefined)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'a' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
const input = document.createElement('input')
|
||||
document.body.append(input)
|
||||
input.focus()
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'b' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
input.remove()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
mockMaskEditorIsOpened.mockReturnValue(true)
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'c' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
|
||||
expect(capture).toHaveBeenCalledTimes(6)
|
||||
|
||||
const canvas = {} as LGraphCanvas
|
||||
LGraphCanvas.prototype.processMouseUp.call(
|
||||
canvas,
|
||||
new MouseEvent('mouseup') as CanvasPointerEvent
|
||||
)
|
||||
|
||||
expect(processMouseUp).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(7)
|
||||
|
||||
const promptCallback = vi.fn()
|
||||
LGraphCanvas.prototype.prompt.call(
|
||||
canvas,
|
||||
'title',
|
||||
'value',
|
||||
promptCallback,
|
||||
new MouseEvent('mouseup') as CanvasPointerEvent
|
||||
)
|
||||
const extendedCallback = prompt.mock.calls[0]?.[2] as
|
||||
| ((value: string) => void)
|
||||
| undefined
|
||||
extendedCallback?.('updated')
|
||||
|
||||
expect(promptCallback).toHaveBeenCalledWith('updated')
|
||||
expect(capture).toHaveBeenCalledTimes(8)
|
||||
|
||||
LiteGraph.ContextMenu.prototype.close.call(
|
||||
{} as InstanceType<typeof LiteGraph.ContextMenu>,
|
||||
new MouseEvent('mouseup')
|
||||
)
|
||||
|
||||
expect(close).toHaveBeenCalledOnce()
|
||||
expect(capture).toHaveBeenCalledTimes(9)
|
||||
} finally {
|
||||
LGraphCanvas.prototype.processMouseUp = originalProcessMouseUp
|
||||
LGraphCanvas.prototype.prompt = originalPrompt
|
||||
LiteGraph.ContextMenu.prototype.close = originalClose
|
||||
windowAddSpy.mockRestore()
|
||||
documentAddSpy.mockRestore()
|
||||
rafSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('ignores repeat keydowns and missing active trackers', async () => {
|
||||
const windowListeners: ListenerMap = {}
|
||||
const windowAddSpy = vi
|
||||
.spyOn(window, 'addEventListener')
|
||||
.mockImplementation((type, listener) => {
|
||||
storeListener(windowListeners, type, listener)
|
||||
})
|
||||
const documentAddSpy = vi
|
||||
.spyOn(document, 'addEventListener')
|
||||
.mockImplementation(() => undefined)
|
||||
const rafSpy = vi
|
||||
.spyOn(window, 'requestAnimationFrame')
|
||||
.mockImplementation((callback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
|
||||
try {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const capture = vi.spyOn(tracker, 'captureCanvasState')
|
||||
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'x', repeat: true })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
expect(capture).not.toHaveBeenCalled()
|
||||
|
||||
mockWorkflowStore.activeWorkflow = null
|
||||
dispatchStored(
|
||||
windowListeners,
|
||||
'keydown',
|
||||
new KeyboardEvent('keydown', { key: 'x' })
|
||||
)
|
||||
await flushAsyncFrame()
|
||||
expect(capture).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
windowAddSpy.mockRestore()
|
||||
documentAddSpy.mockRestore()
|
||||
rafSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('stores executed outputs for the workflow that owns the prompt', () => {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const executed = getApiListener('executed')
|
||||
mockExecutionStore.queuedJobs = {
|
||||
promptA: { workflow: { changeTracker: tracker } }
|
||||
}
|
||||
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
output: { images: ['first'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
merge: true,
|
||||
output: { images: ['second'], text: ['caption'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'missing',
|
||||
node: '2',
|
||||
output: { images: ['ignored'] }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
|
||||
expect(tracker.nodeOutputs).toEqual({
|
||||
1: { images: ['first', 'second'], text: ['caption'] }
|
||||
})
|
||||
})
|
||||
|
||||
it('replaces non-array executed outputs during merge updates', () => {
|
||||
ChangeTracker.init()
|
||||
const tracker = createTracker()
|
||||
const executed = getApiListener('executed')
|
||||
mockExecutionStore.queuedJobs = {
|
||||
promptA: { workflow: { changeTracker: tracker } }
|
||||
}
|
||||
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
output: { value: 'old' }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
executed(
|
||||
new CustomEvent('executed', {
|
||||
detail: {
|
||||
prompt_id: 'promptA',
|
||||
node: '1',
|
||||
merge: true,
|
||||
output: { value: 'new' }
|
||||
} as unknown as ExecutedWsMessage
|
||||
})
|
||||
)
|
||||
|
||||
expect(tracker.nodeOutputs).toEqual({
|
||||
1: { value: 'new' }
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('graphEqual', () => {
|
||||
it('compares workflow nodes as an unordered set and ignores extra.ds', () => {
|
||||
const first = createState(2)
|
||||
const second = {
|
||||
...createState(),
|
||||
nodes: [...first.nodes].reverse(),
|
||||
links: first.links,
|
||||
groups: first.groups,
|
||||
extra: { ds: { scale: 2 } }
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, first)).toBe(true)
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for non-object values and meaningful graph differences', () => {
|
||||
const first = createState(1)
|
||||
const differentNodes = createState(2)
|
||||
const differentLinks = {
|
||||
...first,
|
||||
links: [[1, 1, 0, 2, 0, 'MODEL']]
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, null as never)).toBe(false)
|
||||
expect(ChangeTracker.graphEqual(first, differentNodes)).toBe(false)
|
||||
expect(ChangeTracker.graphEqual(first, differentLinks)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false for extra properties other than viewport state', () => {
|
||||
const first = createState()
|
||||
const second = {
|
||||
...first,
|
||||
extra: { custom: true }
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
|
||||
})
|
||||
|
||||
it.each([
|
||||
'floatingLinks',
|
||||
'reroutes',
|
||||
'groups',
|
||||
'definitions',
|
||||
'subgraphs'
|
||||
] as const)('returns false when %s differs', (key) => {
|
||||
const first = createState()
|
||||
const second = {
|
||||
...first,
|
||||
[key]: [{ id: 1 }]
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
|
||||
expect(ChangeTracker.graphEqual(first, second)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,24 @@
|
||||
import { describe, expect, test, vi } from 'vitest'
|
||||
|
||||
import type { LGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { ComponentWidgetImpl, DOMWidgetImpl } from '@/scripts/domWidget'
|
||||
import {
|
||||
addWidget,
|
||||
ComponentWidgetImpl,
|
||||
DOMWidgetImpl,
|
||||
isComponentWidget,
|
||||
isDOMWidget
|
||||
} from '@/scripts/domWidget'
|
||||
|
||||
const { registerWidget, unregisterWidget } = vi.hoisted(() => ({
|
||||
registerWidget: vi.fn(),
|
||||
unregisterWidget: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/domWidgetStore', () => ({
|
||||
useDomWidgetStore: () => ({
|
||||
unregisterWidget: vi.fn()
|
||||
registerWidget,
|
||||
unregisterWidget
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -24,13 +37,11 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Set a specific Y position
|
||||
originalWidget.y = 66
|
||||
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved
|
||||
expect(clonedWidget.y).toBe(66)
|
||||
expect(clonedWidget.node).toBe(newNode)
|
||||
expect(clonedWidget.name).toBe('test-widget')
|
||||
@@ -48,13 +59,11 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Set a specific Y position
|
||||
originalWidget.y = 42
|
||||
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved
|
||||
expect(clonedWidget.y).toBe(42)
|
||||
expect(clonedWidget.node).toBe(newNode)
|
||||
expect(clonedWidget.element).toBe(mockElement)
|
||||
@@ -71,11 +80,9 @@ describe('DOMWidget Y Position Preservation', () => {
|
||||
options: {}
|
||||
})
|
||||
|
||||
// Don't explicitly set Y (should be 0 by default)
|
||||
const newNode = new LGraphNode('new-node')
|
||||
const clonedWidget = originalWidget.createCopyForNode(newNode)
|
||||
|
||||
// Verify Y position is preserved (should be 0)
|
||||
expect(clonedWidget.y).toBe(0)
|
||||
})
|
||||
})
|
||||
@@ -96,3 +103,271 @@ describe('BaseDOMWidgetImpl.isVisible', () => {
|
||||
expect(widget.isVisible()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('DOMWidgetImpl', () => {
|
||||
test('identifies DOM and component widgets', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const domWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'dom',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
const componentWidget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'component',
|
||||
component: { template: '<div />' },
|
||||
inputSpec: { name: 'component', type: 'STRING' },
|
||||
options: {}
|
||||
})
|
||||
|
||||
expect(isDOMWidget(domWidget)).toBe(true)
|
||||
expect(isDOMWidget(componentWidget)).toBe(false)
|
||||
expect(isComponentWidget(componentWidget)).toBe(true)
|
||||
expect(isComponentWidget(domWidget)).toBe(false)
|
||||
})
|
||||
|
||||
test('uses option-backed values, callbacks, and margins', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
let value = 'initial'
|
||||
const setValue = vi.fn((next: string) => {
|
||||
value = next
|
||||
})
|
||||
const callback = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getValue: () => value,
|
||||
setValue,
|
||||
margin: 4
|
||||
}
|
||||
})
|
||||
widget.callback = callback
|
||||
|
||||
widget.value = 'next'
|
||||
|
||||
expect(widget.value).toBe('next')
|
||||
expect(widget.margin).toBe(4)
|
||||
expect(setValue).toHaveBeenCalledWith('next')
|
||||
expect(callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
|
||||
test('uses default value and margin when options do not provide them', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
|
||||
expect(widget.value).toBe('')
|
||||
expect(widget.margin).toBe(10)
|
||||
})
|
||||
|
||||
test('draws zoom placeholders and delegates visible draws', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(true)
|
||||
const onDraw = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
hideOnZoom: true,
|
||||
margin: 5,
|
||||
onDraw
|
||||
}
|
||||
})
|
||||
const ctx = {
|
||||
beginPath: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillStyle: '#000',
|
||||
rect: vi.fn()
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
|
||||
widget.draw(ctx, node, 100, 10, 40, true)
|
||||
|
||||
expect(ctx.rect).toHaveBeenCalledWith(5, 15, 90, 30)
|
||||
expect(ctx.fill).toHaveBeenCalledOnce()
|
||||
expect(ctx.fillStyle).toBe('#000')
|
||||
expect(onDraw).toHaveBeenCalledWith(widget)
|
||||
})
|
||||
|
||||
test('skips placeholder drawing when hidden', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
vi.spyOn(node, 'isWidgetVisible').mockReturnValue(false)
|
||||
const onDraw = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
hideOnZoom: true,
|
||||
onDraw
|
||||
}
|
||||
})
|
||||
const ctx = {
|
||||
beginPath: vi.fn(),
|
||||
fill: vi.fn(),
|
||||
fillStyle: '#000',
|
||||
rect: vi.fn()
|
||||
} as unknown as CanvasRenderingContext2D
|
||||
|
||||
widget.draw(ctx, node, 100, 10, 40, true)
|
||||
|
||||
expect(ctx.rect).not.toHaveBeenCalled()
|
||||
expect(onDraw).toHaveBeenCalledWith(widget)
|
||||
})
|
||||
|
||||
test('computes hidden, option, percent, and fallback layout sizes', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
node.size = [100, 200]
|
||||
const hiddenWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'hidden',
|
||||
type: 'hidden',
|
||||
element: document.createElement('textarea'),
|
||||
options: {}
|
||||
})
|
||||
const optionWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'option',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getMinHeight: () => 11,
|
||||
getMaxHeight: () => 88,
|
||||
getHeight: () => 44
|
||||
}
|
||||
})
|
||||
const percentWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'percent',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getMinHeight: () => 10,
|
||||
getMaxHeight: () => 60,
|
||||
getHeight: () => '25%'
|
||||
}
|
||||
})
|
||||
const fallbackWidget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'fallback',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
getHeight: () => 40
|
||||
}
|
||||
})
|
||||
|
||||
expect(hiddenWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 0,
|
||||
maxHeight: 0,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(optionWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 11,
|
||||
maxHeight: 88,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(percentWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 10,
|
||||
maxHeight: 60,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(fallbackWidget.computeLayoutSize(node)).toEqual({
|
||||
minHeight: 40,
|
||||
maxHeight: undefined,
|
||||
minWidth: 0
|
||||
})
|
||||
})
|
||||
|
||||
test('registers widgets immediately and through node lifecycle callbacks', () => {
|
||||
registerWidget.mockClear()
|
||||
unregisterWidget.mockClear()
|
||||
const node = new LGraphNode('test-node')
|
||||
node.graph = {} as LGraph
|
||||
const beforeResize = vi.fn()
|
||||
const afterResize = vi.fn()
|
||||
const widget = new DOMWidgetImpl({
|
||||
node,
|
||||
name: 'text',
|
||||
type: 'text',
|
||||
element: document.createElement('textarea'),
|
||||
options: {
|
||||
beforeResize,
|
||||
afterResize
|
||||
}
|
||||
})
|
||||
vi.spyOn(node, 'addCustomWidget')
|
||||
|
||||
addWidget(node, widget)
|
||||
node.onAdded?.(node.graph)
|
||||
node.onResize?.([0, 0])
|
||||
node.onRemoved?.()
|
||||
|
||||
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
|
||||
expect(registerWidget).toHaveBeenCalledWith(widget)
|
||||
expect(registerWidget).toHaveBeenCalledTimes(2)
|
||||
expect(beforeResize).toHaveBeenCalledWith(node)
|
||||
expect(afterResize).toHaveBeenCalledWith(node)
|
||||
expect(unregisterWidget).toHaveBeenCalledWith(widget.id)
|
||||
})
|
||||
|
||||
test('computes component layout and serializes raw values', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const value = { nested: true }
|
||||
const widget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'component',
|
||||
component: { template: '<div />' },
|
||||
inputSpec: { name: 'component', type: 'STRING' },
|
||||
options: {
|
||||
getValue: () => value,
|
||||
getMinHeight: () => 12,
|
||||
getMaxHeight: () => 48
|
||||
}
|
||||
})
|
||||
|
||||
expect(widget.computeLayoutSize()).toEqual({
|
||||
minHeight: 12,
|
||||
maxHeight: 48,
|
||||
minWidth: 0
|
||||
})
|
||||
expect(widget.serializeValue()).toEqual({ nested: true })
|
||||
})
|
||||
|
||||
test('adds DOM widgets through LGraphNode prototype helper', () => {
|
||||
const node = new LGraphNode('test-node')
|
||||
const element = document.createElement('textarea')
|
||||
let value = 'initial'
|
||||
const setValue = vi.fn((next: string) => {
|
||||
value = next
|
||||
})
|
||||
vi.spyOn(node, 'addCustomWidget')
|
||||
|
||||
const widget = node.addDOMWidget('text', 'textarea', element, {
|
||||
getValue: () => value,
|
||||
setValue
|
||||
})
|
||||
const callback = vi.fn()
|
||||
widget.callback = callback
|
||||
widget.value = 'next'
|
||||
|
||||
expect(node.addCustomWidget).toHaveBeenCalledWith(widget)
|
||||
expect(widget.element).toBe(element)
|
||||
expect(widget.options.hideOnZoom).toBe(true)
|
||||
expect(widget.value).toBe('next')
|
||||
expect(setValue).toHaveBeenCalledWith('next')
|
||||
expect(callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
})
|
||||
|
||||
106
src/scripts/errorNodeWidgets.test.ts
Normal file
106
src/scripts/errorNodeWidgets.test.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
interface TestWidgetOptions {
|
||||
name: string
|
||||
type: string
|
||||
default?: unknown
|
||||
multiline?: boolean
|
||||
}
|
||||
|
||||
function createWidgetFactory() {
|
||||
return (node: LGraphNode, options: TestWidgetOptions): IBaseWidget => {
|
||||
const widget: IBaseWidget = fromAny({
|
||||
name: options.name,
|
||||
type: options.type,
|
||||
value: options.default,
|
||||
options
|
||||
})
|
||||
node.widgets = [...(node.widgets ?? []), widget]
|
||||
return widget
|
||||
}
|
||||
}
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
|
||||
() => ({
|
||||
useStringWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
|
||||
() => ({
|
||||
useFloatWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
|
||||
() => ({
|
||||
useBooleanWidget: () => createWidgetFactory()
|
||||
})
|
||||
)
|
||||
|
||||
import './errorNodeWidgets'
|
||||
|
||||
describe('errorNodeWidgets', () => {
|
||||
it('restores widgets from serialized values on error nodes', () => {
|
||||
const node = new LGraphNode('BrokenNode')
|
||||
const longText = 'serialized value with more than twenty chars'
|
||||
node.has_errors = true
|
||||
|
||||
node.onConfigure?.(
|
||||
fromAny({
|
||||
widgets_values: {
|
||||
length: 5,
|
||||
0: 'short text',
|
||||
1: longText,
|
||||
2: 12,
|
||||
3: true,
|
||||
4: { nested: 'value' }
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
expect(node.widgets).toHaveLength(5)
|
||||
expect(node.widgets?.map((widget) => widget.name)).toEqual([
|
||||
'UNKNOWN',
|
||||
'UNKNOWN_1',
|
||||
'UNKNOWN_2',
|
||||
'UNKNOWN_3',
|
||||
'UNKNOWN_4'
|
||||
])
|
||||
expect(node.widgets?.map((widget) => widget.label)).toEqual([
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN',
|
||||
'UNKNOWN'
|
||||
])
|
||||
expect(node.widgets?.map((widget) => widget.value)).toEqual([
|
||||
'short text',
|
||||
longText,
|
||||
12,
|
||||
true,
|
||||
'{"nested":"value"}'
|
||||
])
|
||||
expect(node.serialize_widgets).toBe(true)
|
||||
})
|
||||
|
||||
it('leaves normal nodes unchanged', () => {
|
||||
const node = new LGraphNode('HealthyNode')
|
||||
|
||||
node.onConfigure?.(
|
||||
fromPartial({
|
||||
widgets_values: ['ignored']
|
||||
})
|
||||
)
|
||||
|
||||
expect(node.widgets).toBeUndefined()
|
||||
expect(node.serialize_widgets).toBeUndefined()
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromAvifFile } from './avif'
|
||||
|
||||
@@ -83,6 +84,11 @@ describe('AVIF metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromAvifFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromAvifFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -139,7 +145,8 @@ const buildInfeBox = (
|
||||
itemType: string,
|
||||
version = 2
|
||||
): Uint8Array => {
|
||||
const bodySize = 4 + 2 + 2 + 4 + 1 + 1
|
||||
const itemIdSize = version === 2 ? 2 : version >= 3 ? 4 : 0
|
||||
const bodySize = 4 + itemIdSize + (version >= 2 ? 2 + 4 + 1 + 1 : 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
@@ -147,22 +154,36 @@ const buildInfeBox = (
|
||||
buf.set(new TextEncoder().encode('infe'), 4)
|
||||
buf[8] = version
|
||||
if (version >= 2) {
|
||||
setU16BE(dv, 12, itemId)
|
||||
setU16BE(dv, 14, 0)
|
||||
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), 16)
|
||||
let p = 12
|
||||
if (version === 2) {
|
||||
setU16BE(dv, p, itemId)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, itemId)
|
||||
p += 4
|
||||
}
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
buf.set(new TextEncoder().encode(itemType.padEnd(4).slice(0, 4)), p)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
|
||||
const bodySize = 4 + 2 + infeBoxes.reduce((s, b) => s + b.length, 0)
|
||||
const buildIinfBox = (infeBoxes: Uint8Array[], version = 0): Uint8Array => {
|
||||
const countSize = version === 0 ? 2 : 4
|
||||
const bodySize = 4 + countSize + infeBoxes.reduce((s, b) => s + b.length, 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
setU32BE(dv, 0, totalSize)
|
||||
buf.set(new TextEncoder().encode('iinf'), 4)
|
||||
setU16BE(dv, 12, infeBoxes.length)
|
||||
let off = 14
|
||||
buf[8] = version
|
||||
if (version === 0) {
|
||||
setU16BE(dv, 12, infeBoxes.length)
|
||||
} else {
|
||||
setU32BE(dv, 12, infeBoxes.length)
|
||||
}
|
||||
let off = 8 + 4 + countSize
|
||||
for (const ib of infeBoxes) {
|
||||
buf.set(ib, off)
|
||||
off += ib.length
|
||||
@@ -170,31 +191,91 @@ const buildIinfBox = (infeBoxes: Uint8Array[]): Uint8Array => {
|
||||
return buf
|
||||
}
|
||||
|
||||
interface IlocItem {
|
||||
itemId: number
|
||||
extentOffset: number
|
||||
extentLength: number
|
||||
extents?: Array<{ extentOffset: number; extentLength: number }>
|
||||
}
|
||||
|
||||
const buildIlocBox = (
|
||||
items: { itemId: number; extentOffset: number; extentLength: number }[]
|
||||
items: IlocItem[],
|
||||
{
|
||||
version = 0,
|
||||
baseOffsetSize = 0,
|
||||
indexSize = 0
|
||||
}: { version?: number; baseOffsetSize?: number; indexSize?: number } = {}
|
||||
): Uint8Array => {
|
||||
const perItemSize = 2 + 2 + 0 + 2 + (4 + 4)
|
||||
const bodySize = 4 + 1 + 1 + 2 + items.length * perItemSize
|
||||
const itemCountSize = version < 2 ? 2 : 4
|
||||
const itemIdSize = version < 2 ? 2 : 4
|
||||
const constructionMethodSize = version === 1 || version === 2 ? 2 : 0
|
||||
const itemSizes = items.map((item) => {
|
||||
const extents = item.extents ?? [
|
||||
{
|
||||
extentOffset: item.extentOffset,
|
||||
extentLength: item.extentLength
|
||||
}
|
||||
]
|
||||
return (
|
||||
itemIdSize +
|
||||
constructionMethodSize +
|
||||
2 +
|
||||
baseOffsetSize +
|
||||
2 +
|
||||
extents.length * (indexSize + 4 + 4)
|
||||
)
|
||||
})
|
||||
const bodySize =
|
||||
4 + 1 + 1 + itemCountSize + itemSizes.reduce((sum, size) => sum + size, 0)
|
||||
const totalSize = 8 + bodySize
|
||||
const buf = new Uint8Array(totalSize)
|
||||
const dv = new DataView(buf.buffer)
|
||||
setU32BE(dv, 0, totalSize)
|
||||
buf.set(new TextEncoder().encode('iloc'), 4)
|
||||
buf[8] = version
|
||||
buf[12] = 0x44
|
||||
buf[13] = 0x00
|
||||
setU16BE(dv, 14, items.length)
|
||||
let p = 16
|
||||
for (const it of items) {
|
||||
setU16BE(dv, p, it.itemId)
|
||||
buf[13] = (baseOffsetSize << 4) | indexSize
|
||||
let p = 14
|
||||
if (version < 2) {
|
||||
setU16BE(dv, p, items.length)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, items.length)
|
||||
p += 4
|
||||
}
|
||||
for (const it of items) {
|
||||
if (version < 2) {
|
||||
setU16BE(dv, p, it.itemId)
|
||||
p += 2
|
||||
} else {
|
||||
setU32BE(dv, p, it.itemId)
|
||||
p += 4
|
||||
}
|
||||
if (version === 1 || version === 2) {
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
}
|
||||
setU16BE(dv, p, 0)
|
||||
p += 2
|
||||
setU16BE(dv, p, 1)
|
||||
if (baseOffsetSize > 0) {
|
||||
setU32BE(dv, p, 0)
|
||||
p += baseOffsetSize
|
||||
}
|
||||
const extents = it.extents ?? [
|
||||
{
|
||||
extentOffset: it.extentOffset,
|
||||
extentLength: it.extentLength
|
||||
}
|
||||
]
|
||||
setU16BE(dv, p, extents.length)
|
||||
p += 2
|
||||
setU32BE(dv, p, it.extentOffset)
|
||||
p += 4
|
||||
setU32BE(dv, p, it.extentLength)
|
||||
p += 4
|
||||
for (const extent of extents) {
|
||||
p += indexSize
|
||||
setU32BE(dv, p, extent.extentOffset)
|
||||
p += 4
|
||||
setU32BE(dv, p, extent.extentLength)
|
||||
p += 4
|
||||
}
|
||||
}
|
||||
return buf
|
||||
}
|
||||
@@ -231,7 +312,13 @@ interface BuildAvifOpts {
|
||||
ftypBrand?: string
|
||||
omitMeta?: boolean
|
||||
omitIloc?: boolean
|
||||
iinfVersion?: number
|
||||
infeVersion?: number
|
||||
ilocVersion?: number
|
||||
ilocBaseOffsetSize?: number
|
||||
ilocIndexSize?: number
|
||||
ilocExtents?: Array<{ extentOffset: number; extentLength: number }>
|
||||
rawExifData?: Uint8Array
|
||||
}
|
||||
|
||||
const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
@@ -242,7 +329,13 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
ftypBrand = 'avif',
|
||||
omitMeta = false,
|
||||
omitIloc = false,
|
||||
infeVersion = 2
|
||||
iinfVersion = 0,
|
||||
infeVersion = 2,
|
||||
ilocVersion = 0,
|
||||
ilocBaseOffsetSize = 0,
|
||||
ilocIndexSize = 0,
|
||||
ilocExtents,
|
||||
rawExifData
|
||||
} = opts
|
||||
|
||||
const ftyp = buildFtypBox(ftypBrand)
|
||||
@@ -250,19 +343,43 @@ const buildAvifFile = (opts: BuildAvifOpts = {}): ArrayBuffer => {
|
||||
return ftyp.slice().buffer as ArrayBuffer
|
||||
}
|
||||
|
||||
const exifData = buildExifBlob(exifEntries, endian)
|
||||
const exifData = rawExifData ?? buildExifBlob(exifEntries, endian)
|
||||
const infe = buildInfeBox(1, itemType, infeVersion)
|
||||
const iinf = buildIinfBox([infe])
|
||||
const iinf = buildIinfBox([infe], iinfVersion)
|
||||
|
||||
const realIloc = buildIlocBox([
|
||||
{ itemId: 1, extentOffset: 0, extentLength: exifData.length }
|
||||
])
|
||||
const realIloc = buildIlocBox(
|
||||
[
|
||||
{
|
||||
itemId: 1,
|
||||
extentOffset: 0,
|
||||
extentLength: exifData.length,
|
||||
extents: ilocExtents
|
||||
}
|
||||
],
|
||||
{
|
||||
version: ilocVersion,
|
||||
baseOffsetSize: ilocBaseOffsetSize,
|
||||
indexSize: ilocIndexSize
|
||||
}
|
||||
)
|
||||
const metaSize = 8 + 4 + iinf.length + (omitIloc ? 0 : realIloc.length)
|
||||
const exifOffset = ftyp.length + metaSize
|
||||
|
||||
const finalIloc = buildIlocBox([
|
||||
{ itemId: 1, extentOffset: exifOffset, extentLength: exifData.length }
|
||||
])
|
||||
const finalIloc = buildIlocBox(
|
||||
[
|
||||
{
|
||||
itemId: 1,
|
||||
extentOffset: exifOffset,
|
||||
extentLength: exifData.length,
|
||||
extents: ilocExtents
|
||||
}
|
||||
],
|
||||
{
|
||||
version: ilocVersion,
|
||||
baseOffsetSize: ilocBaseOffsetSize,
|
||||
indexSize: ilocIndexSize
|
||||
}
|
||||
)
|
||||
const finalInner = omitIloc ? [iinf] : [iinf, finalIloc]
|
||||
const meta = buildMetaBox(finalInner)
|
||||
|
||||
@@ -319,6 +436,52 @@ describe('getFromAvifFile', () => {
|
||||
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
|
||||
})
|
||||
|
||||
it('extracts EXIF metadata from versioned item info and location boxes', async () => {
|
||||
const workflow = '{"versioned":true}'
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
exifEntries: [`workflow:${workflow}`],
|
||||
iinfVersion: 1,
|
||||
infeVersion: 3,
|
||||
ilocVersion: 2,
|
||||
ilocBaseOffsetSize: 4,
|
||||
ilocIndexSize: 4
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result.workflow).toBe(JSON.stringify(JSON.parse(workflow)))
|
||||
})
|
||||
|
||||
it('returns {} when the Exif item has no extents', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
exifEntries: ['workflow:{}'],
|
||||
ilocExtents: []
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('returns {} when the Exif payload has no TIFF header', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({
|
||||
rawExifData: new TextEncoder().encode('not tiff data')
|
||||
})
|
||||
)
|
||||
|
||||
const result = await getFromAvifFile(file)
|
||||
|
||||
expect(result).toEqual({})
|
||||
expect(console.log).toHaveBeenCalledWith(
|
||||
'Warning: TIFF header not found in EXIF data.'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns {} when AVIF major brand is not "avif"', async () => {
|
||||
const file = fileFromBuffer(
|
||||
buildAvifFile({ exifEntries: ['workflow:{}'], ftypBrand: 'heic' })
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromWebmFile } from './ebml'
|
||||
|
||||
@@ -16,6 +17,56 @@ const nanFixturePath = path.resolve(
|
||||
__dirname,
|
||||
'__fixtures__/with_nan_metadata.webm'
|
||||
)
|
||||
const WEBM_SIGNATURE = new Uint8Array([0x1a, 0x45, 0xdf, 0xa3])
|
||||
const SIMPLE_TAG = new Uint8Array([0x67, 0xc8])
|
||||
const TAG_NAME = new Uint8Array([0x45, 0xa3])
|
||||
const TAG_VALUE = new Uint8Array([0x44, 0x87])
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
function bytes(...values: number[]) {
|
||||
return new Uint8Array(values)
|
||||
}
|
||||
|
||||
function vint(value: number) {
|
||||
return bytes(0x80 | value)
|
||||
}
|
||||
|
||||
function element(id: Uint8Array, value: string) {
|
||||
const encoded = encoder.encode(value)
|
||||
return concatBytes(id, vint(encoded.length), encoded)
|
||||
}
|
||||
|
||||
function simpleTag(name: string, value: string, useTwoByteSize = false) {
|
||||
const payload = concatBytes(
|
||||
element(TAG_NAME, name),
|
||||
element(TAG_VALUE, value)
|
||||
)
|
||||
const size = useTwoByteSize
|
||||
? bytes(0x40, payload.length)
|
||||
: vint(payload.length)
|
||||
return concatBytes(SIMPLE_TAG, size, payload)
|
||||
}
|
||||
|
||||
async function readWebm(bytes: Uint8Array) {
|
||||
return getFromWebmFile(
|
||||
new File([bytes as Uint8Array<ArrayBuffer>], 'test.webm', {
|
||||
type: 'video/webm'
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
describe('WebM/EBML metadata', () => {
|
||||
it('extracts workflow and prompt from EBML SimpleTag elements', async () => {
|
||||
@@ -46,6 +97,89 @@ describe('WebM/EBML metadata', () => {
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('extracts plain string tags and trims null-terminated values', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(
|
||||
WEBM_SIGNATURE,
|
||||
simpleTag('\0Comment', ' hello \0ignored', true)
|
||||
)
|
||||
)
|
||||
|
||||
expect(result.comment).toBe('hello')
|
||||
})
|
||||
|
||||
it('ignores prompt tags whose value is not complete JSON', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', '{not-json'))
|
||||
)
|
||||
|
||||
expect(result.prompt).toBeUndefined()
|
||||
})
|
||||
|
||||
it('ignores prompt tags whose value has no JSON object', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, simpleTag('PROMPT', 'not-json'))
|
||||
)
|
||||
|
||||
expect(result.prompt).toBeUndefined()
|
||||
})
|
||||
|
||||
it('parses the first complete JSON object from a prompt tag', async () => {
|
||||
const result = await readWebm(
|
||||
concatBytes(
|
||||
WEBM_SIGNATURE,
|
||||
simpleTag('PROMPT', 'prefix {"outer":{"inner":1}} trailing')
|
||||
)
|
||||
)
|
||||
|
||||
expect(result.prompt).toEqual({ outer: { inner: 1 } })
|
||||
})
|
||||
|
||||
it('ignores tags whose name has no readable text', async () => {
|
||||
const payload = concatBytes(
|
||||
concatBytes(TAG_NAME, vint(2), bytes(0, 1)),
|
||||
element(TAG_VALUE, 'value')
|
||||
)
|
||||
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
|
||||
)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores tag elements with zero-sized names', async () => {
|
||||
const payload = concatBytes(TAG_NAME, vint(0), element(TAG_VALUE, 'value'))
|
||||
|
||||
const result = await readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(payload.length), payload)
|
||||
)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores malformed SimpleTag encodings', async () => {
|
||||
const nameOnly = element(TAG_NAME, 'comment')
|
||||
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x00)))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, bytes(0x7f, 0xff)))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(10), TAG_NAME))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readWebm(
|
||||
concatBytes(WEBM_SIGNATURE, SIMPLE_TAG, vint(nameOnly.length), nameOnly)
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -60,5 +194,10 @@ describe('WebM/EBML metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromWebmFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromWebmFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
|
||||
@@ -5,7 +6,8 @@ import { ASCII, GltfSizeBytes } from '@/types/metadataTypes'
|
||||
import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getGltfBinaryMetadata } from './gltf'
|
||||
|
||||
@@ -16,11 +18,15 @@ describe('GLTF binary metadata parser', () => {
|
||||
return { header, headerView }
|
||||
}
|
||||
|
||||
const createJSONChunk = (jsonData: ArrayBuffer) => {
|
||||
const createJSONChunk = (
|
||||
jsonData: ArrayBuffer,
|
||||
chunkType = ASCII.JSON,
|
||||
chunkLength = jsonData.byteLength
|
||||
) => {
|
||||
const chunkHeader = new ArrayBuffer(GltfSizeBytes.CHUNK_HEADER)
|
||||
const chunkView = new DataView(chunkHeader)
|
||||
chunkView.setUint32(0, jsonData.byteLength, true)
|
||||
chunkView.setUint32(4, ASCII.JSON, true)
|
||||
chunkView.setUint32(0, chunkLength, true)
|
||||
chunkView.setUint32(4, chunkType, true)
|
||||
return chunkHeader
|
||||
}
|
||||
|
||||
@@ -52,13 +58,27 @@ describe('GLTF binary metadata parser', () => {
|
||||
// Builds a GLB whose JSON chunk is the literal text passed in - used to
|
||||
// embed Python generated bare NaN/Infinity tokens that JSON.stringify
|
||||
// would otherwise coerce to null.
|
||||
function createMockGltfFileFromText(jsonText: string): File {
|
||||
interface MockGltfFileOptions {
|
||||
chunkLength?: number
|
||||
chunkType?: number
|
||||
magicNumber?: number
|
||||
}
|
||||
|
||||
function createMockGltfFileFromText(
|
||||
jsonText: string,
|
||||
{
|
||||
chunkLength,
|
||||
chunkType = ASCII.JSON,
|
||||
magicNumber = ASCII.GLTF
|
||||
}: MockGltfFileOptions = {}
|
||||
): File {
|
||||
const jsonData = new TextEncoder().encode(jsonText)
|
||||
const { header, headerView } = createGLTFFileStructure()
|
||||
|
||||
setHeaders(headerView, jsonData.buffer)
|
||||
setTypeHeader(headerView, magicNumber)
|
||||
|
||||
const chunkHeader = createJSONChunk(jsonData.buffer)
|
||||
const chunkHeader = createJSONChunk(jsonData.buffer, chunkType, chunkLength)
|
||||
|
||||
const fileContent = new Uint8Array(
|
||||
header.byteLength + chunkHeader.byteLength + jsonData.byteLength
|
||||
@@ -130,7 +150,9 @@ describe('GLTF binary metadata parser', () => {
|
||||
expect(metadata).toBeDefined()
|
||||
expect(metadata.prompt).toBeDefined()
|
||||
|
||||
const prompt = metadata.prompt as Record<string, any>
|
||||
const prompt: {
|
||||
node1: { class_type: string; inputs: { seed: number } }
|
||||
} = fromAny(metadata.prompt)
|
||||
expect(prompt.node1.class_type).toBe('TestNode')
|
||||
expect(prompt.node1.inputs.seed).toBe(123456)
|
||||
})
|
||||
@@ -179,6 +201,74 @@ describe('GLTF binary metadata parser', () => {
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the GLB magic number is invalid', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { magicNumber: 0 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the GLB is missing the first chunk header', async () => {
|
||||
const { header, headerView } = createGLTFFileStructure()
|
||||
setTypeHeader(headerView, ASCII.GLTF)
|
||||
setVersionHeader(headerView, 2)
|
||||
setTotalLengthHeader(headerView, GltfSizeBytes.HEADER)
|
||||
const mockFile = new File([header], 'header-only.glb')
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the first chunk is not JSON', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { chunkType: 0 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the declared JSON chunk exceeds the buffer', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{}', { chunkLength: 1024 })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the JSON chunk cannot be parsed', async () => {
|
||||
const mockFile = createMockGltfFileFromText('{not json')
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when asset extras are missing', async () => {
|
||||
const mockFile = createMockGltfFile({ asset: { version: '2.0' } })
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
it('skips string metadata values that are not valid JSON', async () => {
|
||||
const mockFile = createMockGltfFile({
|
||||
asset: {
|
||||
version: '2.0',
|
||||
extras: {
|
||||
prompt: '{not json',
|
||||
workflow: '{not json'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const metadata = await getGltfBinaryMetadata(mockFile)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -193,5 +283,15 @@ describe('GLTF binary metadata parser', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load result is missing', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader result is not a buffer', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', 'not a buffer')
|
||||
expect(await getGltfBinaryMetadata(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -7,7 +7,8 @@ import {
|
||||
EXPECTED_PROMPT_NAN_COERCED,
|
||||
EXPECTED_WORKFLOW,
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromIsobmffFile } from './isobmff'
|
||||
|
||||
@@ -16,6 +17,82 @@ const nanFixturePath = path.resolve(
|
||||
__dirname,
|
||||
'__fixtures__/with_nan_metadata.mp4'
|
||||
)
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
function uint32(value: number) {
|
||||
return new Uint8Array([
|
||||
(value >>> 24) & 0xff,
|
||||
(value >>> 16) & 0xff,
|
||||
(value >>> 8) & 0xff,
|
||||
value & 0xff
|
||||
])
|
||||
}
|
||||
|
||||
function concatBytes(...chunks: Uint8Array[]) {
|
||||
const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0)
|
||||
const result = new Uint8Array(totalLength)
|
||||
let offset = 0
|
||||
|
||||
for (const chunk of chunks) {
|
||||
result.set(chunk, offset)
|
||||
offset += chunk.length
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
function box(type: string, payload = new Uint8Array(), size?: number) {
|
||||
return concatBytes(
|
||||
uint32(size ?? 8 + payload.length),
|
||||
encoder.encode(type),
|
||||
payload
|
||||
)
|
||||
}
|
||||
|
||||
function keyEntry(name: string) {
|
||||
const encoded = encoder.encode(name)
|
||||
return concatBytes(
|
||||
uint32(8 + encoded.length),
|
||||
encoder.encode('mdta'),
|
||||
encoded
|
||||
)
|
||||
}
|
||||
|
||||
function keysBox(names: string[]) {
|
||||
return box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(names.length), ...names.map(keyEntry))
|
||||
)
|
||||
}
|
||||
|
||||
function dataBox(value: string | Uint8Array) {
|
||||
const payload = typeof value === 'string' ? encoder.encode(value) : value
|
||||
return box('data', concatBytes(uint32(0), uint32(0), payload))
|
||||
}
|
||||
|
||||
function ilstItem(index: number, payload: Uint8Array) {
|
||||
return concatBytes(uint32(8 + payload.length), uint32(index), payload)
|
||||
}
|
||||
|
||||
function ilstBox(...items: Uint8Array[]) {
|
||||
return box('ilst', concatBytes(...items))
|
||||
}
|
||||
|
||||
function metaBox(...children: Uint8Array[]) {
|
||||
return box('meta', concatBytes(uint32(0), ...children))
|
||||
}
|
||||
|
||||
function udtaWithMeta(...children: Uint8Array[]) {
|
||||
return box('udta', metaBox(...children))
|
||||
}
|
||||
|
||||
async function readMp4(bytes: Uint8Array) {
|
||||
return getFromIsobmffFile(
|
||||
new File([bytes as Uint8Array<ArrayBuffer>], 'test.mp4', {
|
||||
type: 'video/mp4'
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
describe('ISOBMFF (MP4) metadata', () => {
|
||||
it('extracts workflow and prompt from QuickTime keys/ilst boxes', async () => {
|
||||
@@ -48,6 +125,102 @@ describe('ISOBMFF (MP4) metadata', () => {
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('extracts metadata from udta nested inside moov', async () => {
|
||||
const bytes = box(
|
||||
'moov',
|
||||
udtaWithMeta(
|
||||
keysBox(['WORKFLOW']),
|
||||
ilstBox(ilstItem(1, dataBox('xxxx{"nodes":[]}')))
|
||||
)
|
||||
)
|
||||
|
||||
const result = await readMp4(bytes)
|
||||
|
||||
expect(result.workflow).toEqual({ nodes: [] })
|
||||
})
|
||||
|
||||
it('returns empty when a top-level box declares an impossible size', async () => {
|
||||
const result = await readMp4(box('free', new Uint8Array([1, 2]), 100))
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when the keys box cannot provide entries', async () => {
|
||||
const tooShortKeys = box('keys', uint32(0))
|
||||
const missingKeyEntry = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(8))
|
||||
)
|
||||
const malformedKey = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(7), encoder.encode('bad!'))
|
||||
)
|
||||
const oversizedKey = box(
|
||||
'keys',
|
||||
concatBytes(uint32(0), uint32(1), uint32(100), encoder.encode('bad!'))
|
||||
)
|
||||
|
||||
await expect(readMp4(udtaWithMeta(tooShortKeys))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(missingKeyEntry))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(malformedKey))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(oversizedKey))).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('ignores item entries whose key is unknown or unsupported', async () => {
|
||||
const unknownIndex = udtaWithMeta(
|
||||
keysBox(['PROMPT']),
|
||||
ilstBox(ilstItem(2, dataBox('{"1":{}}')))
|
||||
)
|
||||
const unsupportedKey = udtaWithMeta(
|
||||
keysBox(['DESCRIPTION']),
|
||||
ilstBox(ilstItem(1, dataBox('{"ignored":true}')))
|
||||
)
|
||||
|
||||
await expect(readMp4(unknownIndex)).resolves.toEqual({})
|
||||
await expect(readMp4(unsupportedKey)).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('ignores metadata items without readable JSON data', async () => {
|
||||
const shortDataBox = box('data')
|
||||
const noJson = dataBox('not-json')
|
||||
const invalidJson = dataBox('prefix {not-json')
|
||||
const noDataBox = box('free', new Uint8Array([1]))
|
||||
const invalidItems = ilstBox(
|
||||
concatBytes(uint32(8), uint32(1)),
|
||||
concatBytes(uint32(100), uint32(1), invalidJson)
|
||||
)
|
||||
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, shortDataBox)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noJson))))
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, invalidJson)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(
|
||||
udtaWithMeta(keysBox(['PROMPT']), ilstBox(ilstItem(1, noDataBox)))
|
||||
)
|
||||
).resolves.toEqual({})
|
||||
await expect(
|
||||
readMp4(udtaWithMeta(keysBox(['PROMPT']), invalidItems))
|
||||
).resolves.toEqual({})
|
||||
})
|
||||
|
||||
it('returns empty when required metadata boxes are absent', async () => {
|
||||
await expect(readMp4(box('udta', box('free')))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(ilstBox()))).resolves.toEqual({})
|
||||
await expect(readMp4(udtaWithMeta(keysBox(['PROMPT'])))).resolves.toEqual(
|
||||
{}
|
||||
)
|
||||
})
|
||||
|
||||
describe('FileReader failure modes', () => {
|
||||
afterEach(() => vi.restoreAllMocks())
|
||||
|
||||
@@ -63,5 +236,10 @@ describe('ISOBMFF (MP4) metadata', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
expect(await getFromIsobmffFile(file)).toEqual({})
|
||||
})
|
||||
|
||||
it('resolves empty when the FileReader load has no result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
expect(await getFromIsobmffFile(file)).toEqual({})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
127
src/scripts/metadata/parser.test.ts
Normal file
127
src/scripts/metadata/parser.test.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getFromWebmFile } from '@/scripts/metadata/ebml'
|
||||
import { getGltfBinaryMetadata } from '@/scripts/metadata/gltf'
|
||||
import { getFromIsobmffFile } from '@/scripts/metadata/isobmff'
|
||||
import { getDataFromJSON } from '@/scripts/metadata/json'
|
||||
import { getMp3Metadata } from '@/scripts/metadata/mp3'
|
||||
import { getOggMetadata } from '@/scripts/metadata/ogg'
|
||||
import { getWorkflowDataFromFile } from '@/scripts/metadata/parser'
|
||||
import { getSvgMetadata } from '@/scripts/metadata/svg'
|
||||
import {
|
||||
getAvifMetadata,
|
||||
getFlacMetadata,
|
||||
getLatentMetadata,
|
||||
getPngMetadata,
|
||||
getWebpMetadata
|
||||
} from '@/scripts/pnginfo'
|
||||
|
||||
vi.mock('@/scripts/metadata/ebml', () => ({ getFromWebmFile: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/gltf', () => ({ getGltfBinaryMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/isobmff', () => ({ getFromIsobmffFile: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/json', () => ({ getDataFromJSON: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/mp3', () => ({ getMp3Metadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/ogg', () => ({ getOggMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/metadata/svg', () => ({ getSvgMetadata: vi.fn() }))
|
||||
vi.mock('@/scripts/pnginfo', () => ({
|
||||
getAvifMetadata: vi.fn(),
|
||||
getFlacMetadata: vi.fn(),
|
||||
getLatentMetadata: vi.fn(),
|
||||
getPngMetadata: vi.fn(),
|
||||
getWebpMetadata: vi.fn()
|
||||
}))
|
||||
|
||||
function file(type: string, name = 'file') {
|
||||
return new File(['data'], name, { type })
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('getWorkflowDataFromFile', () => {
|
||||
it('routes png/avif/mp3/ogg/webm to their parsers and returns the result', async () => {
|
||||
vi.mocked(getPngMetadata).mockResolvedValue({ a: 1 } as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/png'))).toEqual({ a: 1 })
|
||||
expect(getPngMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('image/avif'))
|
||||
expect(getAvifMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('audio/mpeg'))
|
||||
expect(getMp3Metadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('audio/ogg'))
|
||||
expect(getOggMetadata).toHaveBeenCalled()
|
||||
|
||||
await getWorkflowDataFromFile(file('video/webm'))
|
||||
expect(getFromWebmFile).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('extracts workflow/prompt from webp, preferring lowercase keys', async () => {
|
||||
vi.mocked(getWebpMetadata).mockResolvedValue({
|
||||
workflow: 'wf',
|
||||
prompt: 'pr'
|
||||
} as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
|
||||
workflow: 'wf',
|
||||
prompt: 'pr'
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to capitalized webp keys when lowercase are absent', async () => {
|
||||
vi.mocked(getWebpMetadata).mockResolvedValue({
|
||||
Workflow: 'WF',
|
||||
Prompt: 'PR'
|
||||
} as never)
|
||||
expect(await getWorkflowDataFromFile(file('image/webp'))).toEqual({
|
||||
workflow: 'WF',
|
||||
prompt: 'PR'
|
||||
})
|
||||
})
|
||||
|
||||
it('handles both flac mime types and extracts workflow/prompt', async () => {
|
||||
vi.mocked(getFlacMetadata).mockResolvedValue({ workflow: 'w' } as never)
|
||||
expect(await getWorkflowDataFromFile(file('audio/flac'))).toEqual({
|
||||
workflow: 'w',
|
||||
prompt: undefined
|
||||
})
|
||||
expect(await getWorkflowDataFromFile(file('audio/x-flac'))).toEqual({
|
||||
workflow: 'w',
|
||||
prompt: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('routes isobmff by mime type and by file extension', async () => {
|
||||
await getWorkflowDataFromFile(file('video/mp4'))
|
||||
await getWorkflowDataFromFile(file('', 'clip.mov'))
|
||||
await getWorkflowDataFromFile(file('', 'clip.m4v'))
|
||||
expect(getFromIsobmffFile).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('routes svg and gltf by mime type or extension', async () => {
|
||||
await getWorkflowDataFromFile(file('image/svg+xml'))
|
||||
await getWorkflowDataFromFile(file('', 'icon.svg'))
|
||||
expect(getSvgMetadata).toHaveBeenCalledTimes(2)
|
||||
|
||||
await getWorkflowDataFromFile(file('model/gltf-binary'))
|
||||
await getWorkflowDataFromFile(file('', 'model.glb'))
|
||||
expect(getGltfBinaryMetadata).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('routes latent/safetensors and json by extension or mime type', async () => {
|
||||
await getWorkflowDataFromFile(file('', 'x.latent'))
|
||||
await getWorkflowDataFromFile(file('', 'x.safetensors'))
|
||||
expect(getLatentMetadata).toHaveBeenCalledTimes(2)
|
||||
|
||||
await getWorkflowDataFromFile(file('application/json'))
|
||||
await getWorkflowDataFromFile(file('', 'x.json'))
|
||||
expect(getDataFromJSON).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('returns undefined for an unrecognized file', async () => {
|
||||
expect(
|
||||
await getWorkflowDataFromFile(file('application/zip', 'a.zip'))
|
||||
).toBe(undefined)
|
||||
})
|
||||
})
|
||||
@@ -2,7 +2,8 @@ import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
mockFileReaderAbort,
|
||||
mockFileReaderError
|
||||
mockFileReaderError,
|
||||
mockFileReaderResult
|
||||
} from './__fixtures__/helpers'
|
||||
import { getFromPngBuffer, getFromPngFile } from './png'
|
||||
|
||||
@@ -191,6 +192,27 @@ describe('getFromPngBuffer', () => {
|
||||
const result = await getFromPngBuffer(buffer)
|
||||
expect(result['workflow']).toBe(workflow)
|
||||
})
|
||||
|
||||
it('logs error and skips compressed iTXt with invalid deflate data', async () => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const buffer = createPngWithChunk(
|
||||
'iTXt',
|
||||
'workflow',
|
||||
new Uint8Array([1, 2, 3]),
|
||||
{
|
||||
compressionFlag: 1,
|
||||
compressionMethod: 0
|
||||
}
|
||||
)
|
||||
|
||||
const result = await getFromPngBuffer(buffer)
|
||||
|
||||
expect(result['workflow']).toBeUndefined()
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to decompress iTXt chunk "workflow"'),
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFromPngFile', () => {
|
||||
@@ -228,5 +250,12 @@ describe('getFromPngFile', () => {
|
||||
mockFileReaderAbort('readAsArrayBuffer')
|
||||
await expect(getFromPngFile(file)).rejects.toThrow('FileReader aborted')
|
||||
})
|
||||
|
||||
it('rejects when the FileReader load has no ArrayBuffer result', async () => {
|
||||
mockFileReaderResult('readAsArrayBuffer', null)
|
||||
await expect(getFromPngFile(file)).rejects.toThrow(
|
||||
'Failed to read file as ArrayBuffer'
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
import type { LGraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import { api } from '@/scripts/api'
|
||||
import { getFromAvifFile } from './metadata/avif'
|
||||
import { getFromFlacFile } from './metadata/flac'
|
||||
import { getFromPngFile } from './metadata/png'
|
||||
import {
|
||||
getAvifMetadata,
|
||||
getFlacMetadata,
|
||||
importA1111,
|
||||
getLatentMetadata,
|
||||
getPngMetadata,
|
||||
getWebpMetadata
|
||||
@@ -56,6 +62,20 @@ function encodeAsciiIfd(entries: AsciiIfdEntry[]): Uint8Array {
|
||||
return buf
|
||||
}
|
||||
|
||||
function encodeNonAsciiIfdEntry(tag: number): Uint8Array {
|
||||
const buf = new Uint8Array(22)
|
||||
const dv = new DataView(buf.buffer)
|
||||
buf.set([0x49, 0x49], 0)
|
||||
dv.setUint16(2, 0x002a, true)
|
||||
dv.setUint32(4, 8, true)
|
||||
dv.setUint16(8, 1, true)
|
||||
dv.setUint16(10, tag, true)
|
||||
dv.setUint16(12, 3, true)
|
||||
dv.setUint32(14, 1, true)
|
||||
dv.setUint32(18, 123, true)
|
||||
return buf
|
||||
}
|
||||
|
||||
type WebpChunk = { type: string; payload: Uint8Array }
|
||||
|
||||
function wrapInWebp(chunks: WebpChunk[]): File {
|
||||
@@ -157,6 +177,16 @@ describe('getWebpMetadata', () => {
|
||||
|
||||
expect(metadata).toEqual({ workflow: '{"a":1}' })
|
||||
})
|
||||
|
||||
it('ignores EXIF entries that are not ASCII strings', async () => {
|
||||
const file = wrapInWebp([
|
||||
{ type: 'EXIF', payload: encodeNonAsciiIfdEntry(270) }
|
||||
])
|
||||
|
||||
const metadata = await getWebpMetadata(file)
|
||||
|
||||
expect(metadata).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getLatentMetadata', () => {
|
||||
@@ -234,3 +264,313 @@ describe('format-specific metadata wrappers', () => {
|
||||
expect(result).toEqual({ workflow: '{"avif":1}' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('importA1111', () => {
|
||||
function widget(
|
||||
name: string,
|
||||
options: string[] = []
|
||||
): IBaseWidget & { value?: string | number } {
|
||||
return fromPartial<IBaseWidget & { value?: string | number }>({
|
||||
name,
|
||||
options: { values: options },
|
||||
value: undefined
|
||||
})
|
||||
}
|
||||
|
||||
function createNode(type: string): LGraphNode {
|
||||
const widgetsByType: Record<string, IBaseWidget[]> = {
|
||||
CheckpointLoaderSimple: [widget('ckpt_name', ['sd15.safetensors'])],
|
||||
CLIPSetLastLayer: [widget('stop_at_clip_layer')],
|
||||
CLIPTextEncode: [widget('text')],
|
||||
EmptyLatentImage: [widget('width'), widget('height')],
|
||||
ImageScale: [widget('width'), widget('height')],
|
||||
ImageUpscaleWithModel: [],
|
||||
KSampler: [
|
||||
widget('cfg'),
|
||||
widget('sampler_name', ['euler_a', 'dpmpp_2m']),
|
||||
widget('scheduler', ['normal', 'karras']),
|
||||
widget('seed'),
|
||||
widget('steps'),
|
||||
widget('denoise')
|
||||
],
|
||||
LatentUpscale: [
|
||||
widget('upscale_method', ['nearest-exact']),
|
||||
widget('width'),
|
||||
widget('height')
|
||||
],
|
||||
LoraLoader: [
|
||||
widget('lora_name', ['foo.safetensors']),
|
||||
widget('strength_model'),
|
||||
widget('strength_clip')
|
||||
],
|
||||
UpscaleModelLoader: [widget('model_name', ['ESRGAN'])],
|
||||
VAEEncodeTiled: [],
|
||||
VAEDecodeTiled: [],
|
||||
SaveImage: [],
|
||||
VAEDecode: []
|
||||
}
|
||||
return {
|
||||
type,
|
||||
widgets: widgetsByType[type] ?? [],
|
||||
connect: vi.fn()
|
||||
} as unknown as LGraphNode
|
||||
}
|
||||
|
||||
function createGraph(): LGraph {
|
||||
return fromPartial<LGraph>({
|
||||
add: vi.fn(),
|
||||
arrange: vi.fn(),
|
||||
clear: vi.fn()
|
||||
})
|
||||
}
|
||||
|
||||
function findWidget(node: LGraphNode, name: string) {
|
||||
return node.widgets?.find((widget) => widget.name === name)
|
||||
}
|
||||
|
||||
it('ignores text without parsed generation settings', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
|
||||
await importA1111(graph, 'positive prompt only')
|
||||
await importA1111(graph, 'positive prompt\nSteps:\n')
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('ignores text without a negative prompt section', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
['positive prompt only', 'Steps: 20, Sampler: Euler a'].join('\n')
|
||||
)
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('stops when a required base node cannot be created', async () => {
|
||||
const graph = createGraph()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) =>
|
||||
type === 'KSampler' ? null : createNode(type)
|
||||
)
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'prompt',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Euler a, Size: 512x512'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
expect(graph.clear).not.toHaveBeenCalled()
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Failed to create required nodes for A1111 import'
|
||||
)
|
||||
})
|
||||
|
||||
it('builds a basic graph from A1111 parameters', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue(['easynegative'])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'<lora:foo:0.7> portrait easynegative',
|
||||
'Negative prompt: blurry <lora:bad:not-number>',
|
||||
'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x512, Model: sd15, Clip skip: 2, Model hash: ignored'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const checkpoint = nodes.find(
|
||||
(node) => node.type === 'CheckpointLoaderSimple'
|
||||
)
|
||||
const clipSkip = nodes.find((node) => node.type === 'CLIPSetLastLayer')
|
||||
const sampler = nodes.find((node) => node.type === 'KSampler')
|
||||
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
|
||||
const lora = nodes.find((node) => node.type === 'LoraLoader')
|
||||
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
|
||||
|
||||
expect(graph.clear).toHaveBeenCalledOnce()
|
||||
expect(graph.arrange).toHaveBeenCalledOnce()
|
||||
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('sd15.safetensors')
|
||||
expect(findWidget(clipSkip!, 'stop_at_clip_layer')?.value).toBe(-2)
|
||||
expect(findWidget(sampler!, 'cfg')?.value).toBe(7)
|
||||
expect(findWidget(sampler!, 'sampler_name')?.value).toBe('euler_a')
|
||||
expect(findWidget(sampler!, 'scheduler')?.value).toBe('normal')
|
||||
expect(findWidget(sampler!, 'seed')?.value).toBe(42)
|
||||
expect(findWidget(sampler!, 'steps')?.value).toBe(20)
|
||||
expect(findWidget(image!, 'width')?.value).toBe(512)
|
||||
expect(findWidget(image!, 'height')?.value).toBe(512)
|
||||
expect(findWidget(lora!, 'lora_name')?.value).toBe('foo.safetensors')
|
||||
expect(findWidget(lora!, 'strength_model')?.value).toBe(0.7)
|
||||
expect(findWidget(lora!, 'strength_clip')?.value).toBe(0.7)
|
||||
expect(findWidget(textNodes[0], 'text')?.value).toBe(
|
||||
' portrait embedding:easynegative'
|
||||
)
|
||||
expect(findWidget(textNodes[1], 'text')?.value).toBe('blurry ')
|
||||
})
|
||||
|
||||
it('keeps unknown option-prefix values and logs the mismatch', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Unknown Sampler, CFG scale: 7, Seed: 42, Size: 512x512, Model: unknown-model'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const checkpoint = nodes.find(
|
||||
(node) => node.type === 'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(findWidget(checkpoint!, 'ckpt_name')?.value).toBe('unknown-model')
|
||||
expect(console.warn).toHaveBeenCalledWith(
|
||||
"Unknown value 'unknown-model' for widget 'ckpt_name'",
|
||||
checkpoint
|
||||
)
|
||||
})
|
||||
|
||||
it('skips missing LoraLoader nodes while keeping prompt text cleaned', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
if (type === 'LoraLoader') return null
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'<lora:missing:0.5> portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 20, Sampler: Euler a, Size: 512x512'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const textNodes = nodes.filter((node) => node.type === 'CLIPTextEncode')
|
||||
expect(findWidget(textNodes[0], 'text')?.value).toBe(' portrait')
|
||||
})
|
||||
|
||||
it('returns from latent hires setup when LatentUpscale cannot be created', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
if (type === 'LatentUpscale') return null
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 8, Sampler: Euler a, Size: 512x512, Hires upscale: 2, Hires upscaler: Latent'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
expect(nodes.some((node) => node.type === 'KSampler')).toBe(true)
|
||||
expect(nodes.some((node) => node.type === 'LatentUpscale')).toBe(false)
|
||||
})
|
||||
|
||||
it('builds a latent hires pass with explicit resize and denoise settings', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 12, Sampler: DPM++ 2M Karras, CFG scale: 5, Seed: 1, Size: 513x577, Model: sd15, Hires resize: 1025x1089, Hires steps: 4, Hires upscaler: Latent (nearest-exact), Denoising strength: 0.35'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const image = nodes.find((node) => node.type === 'EmptyLatentImage')
|
||||
const latentUpscale = nodes.find((node) => node.type === 'LatentUpscale')
|
||||
const samplers = nodes.filter((node) => node.type === 'KSampler')
|
||||
|
||||
expect(findWidget(image!, 'width')?.value).toBe(576)
|
||||
expect(findWidget(image!, 'height')?.value).toBe(640)
|
||||
expect(findWidget(latentUpscale!, 'upscale_method')?.value).toBe(
|
||||
'nearest-exact'
|
||||
)
|
||||
expect(findWidget(latentUpscale!, 'width')?.value).toBe(1088)
|
||||
expect(findWidget(latentUpscale!, 'height')?.value).toBe(1152)
|
||||
expect(findWidget(samplers[0], 'scheduler')?.value).toBe('karras')
|
||||
expect(findWidget(samplers[0], 'sampler_name')?.value).toBe('dpmpp_2m')
|
||||
expect(findWidget(samplers[1], 'steps')?.value).toBe(4)
|
||||
expect(findWidget(samplers[1], 'cfg')?.value).toBe(5)
|
||||
expect(findWidget(samplers[1], 'scheduler')?.value).toBe('karras')
|
||||
expect(findWidget(samplers[1], 'sampler_name')?.value).toBe('dpmpp_2m')
|
||||
expect(findWidget(samplers[1], 'denoise')?.value).toBe(0.35)
|
||||
})
|
||||
|
||||
it('builds an image upscaler hires pass with fallback steps and denoise', async () => {
|
||||
const graph = createGraph()
|
||||
const nodes: LGraphNode[] = []
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
vi.spyOn(api, 'getEmbeddings').mockResolvedValue([])
|
||||
vi.spyOn(LiteGraph, 'createNode').mockImplementation((type) => {
|
||||
const node = createNode(type)
|
||||
nodes.push(node)
|
||||
return node
|
||||
})
|
||||
|
||||
await importA1111(
|
||||
graph,
|
||||
[
|
||||
'portrait',
|
||||
'Negative prompt: blurry',
|
||||
'Steps: 8, Sampler: Euler a, CFG scale: 6, Seed: 2, Size: 512x512, Model: sd15, Hires upscale: 1.5, Hires upscaler: ESRGAN'
|
||||
].join('\n')
|
||||
)
|
||||
|
||||
const upscaleLoader = nodes.find(
|
||||
(node) => node.type === 'UpscaleModelLoader'
|
||||
)
|
||||
const imageScale = nodes.find((node) => node.type === 'ImageScale')
|
||||
const samplers = nodes.filter((node) => node.type === 'KSampler')
|
||||
|
||||
expect(findWidget(upscaleLoader!, 'model_name')?.value).toBe('ESRGAN')
|
||||
expect(findWidget(imageScale!, 'width')?.value).toBe(768)
|
||||
expect(findWidget(imageScale!, 'height')?.value).toBe(768)
|
||||
expect(findWidget(samplers[1], 'steps')?.value).toBe(8)
|
||||
expect(findWidget(samplers[1], 'denoise')?.value).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
450
src/scripts/ui.test.ts
Normal file
450
src/scripts/ui.test.ts
Normal file
@@ -0,0 +1,450 @@
|
||||
import { fromAny, fromPartial } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { mockSettingStore } = vi.hoisted(() => ({
|
||||
mockSettingStore: {
|
||||
get: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useRunButtonTelemetry', () => ({
|
||||
useRunButtonTelemetry: () => ({ trackRunButton: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/remote/comfyui/jobs/fetchJobs', () => ({
|
||||
extractWorkflow: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/composables/useSettingsDialog', () => ({
|
||||
useSettingsDialog: () => ({ show: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => mockSettingStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackWorkflowExecution: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({ resetView: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({ execute: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/workspaceStore', () => ({
|
||||
useWorkspaceStore: () => ({ focusMode: false })
|
||||
}))
|
||||
|
||||
vi.mock('./api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn(),
|
||||
clearItems: vi.fn(),
|
||||
deleteItem: vi.fn(),
|
||||
dispatchCustomEvent: vi.fn(),
|
||||
getHistory: vi.fn(),
|
||||
getJobDetail: vi.fn(),
|
||||
getQueue: vi.fn(),
|
||||
interrupt: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./app', () => ({
|
||||
app: {
|
||||
clean: vi.fn(),
|
||||
handleFile: vi.fn(),
|
||||
loadGraphData: vi.fn(),
|
||||
openClipspace: vi.fn(),
|
||||
queuePrompt: vi.fn(),
|
||||
refreshComboInNodes: vi.fn(() => Promise.resolve())
|
||||
},
|
||||
ComfyApp: class ComfyApp {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/dialog', () => ({
|
||||
ComfyDialog: class ComfyDialog {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/settings', () => ({
|
||||
ComfySettingsDialog: class ComfySettingsDialog {}
|
||||
}))
|
||||
|
||||
vi.mock('./ui/toggleSwitch', () => ({
|
||||
toggleSwitch: () => document.createElement('div')
|
||||
}))
|
||||
|
||||
import { extractWorkflow } from '@/platform/remote/comfyui/jobs/fetchJobs'
|
||||
|
||||
import { api } from './api'
|
||||
import { app } from './app'
|
||||
import { $el, ComfyUI } from './ui'
|
||||
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
localStorage.clear()
|
||||
vi.clearAllMocks()
|
||||
mockSettingStore.get.mockReturnValue(undefined)
|
||||
Object.assign(app, {
|
||||
lastExecutionError: undefined,
|
||||
nodeOutputs: undefined
|
||||
})
|
||||
})
|
||||
|
||||
async function click(button: HTMLButtonElement) {
|
||||
const handler = button.onclick
|
||||
expect(handler).toBeTypeOf('function')
|
||||
await Promise.resolve(handler?.call(button, fromAny(new MouseEvent('click'))))
|
||||
}
|
||||
|
||||
function buttonByText(root: ParentNode, text: string): HTMLButtonElement {
|
||||
const button = [...root.querySelectorAll('button')].find(
|
||||
(candidate) => candidate.textContent === text
|
||||
)
|
||||
if (!button) throw new Error(`Missing button: ${text}`)
|
||||
return button
|
||||
}
|
||||
|
||||
describe('$el', () => {
|
||||
it('creates elements with classes, children, props, and callbacks', () => {
|
||||
const parent = document.createElement('section')
|
||||
const child = document.createElement('span')
|
||||
const callback = vi.fn()
|
||||
|
||||
const element = $el(
|
||||
'label.primary.secondary',
|
||||
{
|
||||
parent,
|
||||
$: callback,
|
||||
dataset: { role: 'name' },
|
||||
for: 'target-input',
|
||||
style: { display: 'block' },
|
||||
title: 'Label'
|
||||
},
|
||||
child
|
||||
)
|
||||
|
||||
expect(element.tagName).toBe('LABEL')
|
||||
expect(element.classList.contains('primary')).toBe(true)
|
||||
expect(element.classList.contains('secondary')).toBe(true)
|
||||
expect(element.dataset.role).toBe('name')
|
||||
expect(element.getAttribute('for')).toBe('target-input')
|
||||
expect(element.style.display).toBe('block')
|
||||
expect(element.title).toBe('Label')
|
||||
expect(element.firstElementChild).toBe(child)
|
||||
expect(parent.firstElementChild).toBe(element)
|
||||
expect(callback).toHaveBeenCalledWith(element)
|
||||
})
|
||||
|
||||
it('accepts string and single-element children shorthands', () => {
|
||||
const textElement = $el('button', 'Run')
|
||||
const child = document.createElement('strong')
|
||||
const wrapper = $el('div', child)
|
||||
|
||||
expect(textElement.textContent).toBe('Run')
|
||||
expect(wrapper.firstElementChild).toBe(child)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ComfyUI legacy menu', () => {
|
||||
it('loads queue items and runs list actions', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: [{ id: 'pending', priority: 2 }]
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue({
|
||||
outputs: { node: { images: ['image.png'] } }
|
||||
} as never)
|
||||
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(api.getJobDetail).toHaveBeenCalledWith('running')
|
||||
expect(extractWorkflow).toHaveBeenCalled()
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
|
||||
expect(app.nodeOutputs).toEqual({ node: { images: ['image.png'] } })
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Cancel'))
|
||||
expect(api.interrupt).toHaveBeenCalledWith('running')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Delete'))
|
||||
expect(api.deleteItem).toHaveBeenCalledWith('queue', 'pending')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Clear Queue'))
|
||||
expect(api.clearItems).toHaveBeenCalledWith('queue')
|
||||
|
||||
await click(buttonByText(ui.queue.element, 'Refresh'))
|
||||
expect(api.getQueue).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips loading queue items when job details are unavailable', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: []
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue(null as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(extractWorkflow).not.toHaveBeenCalled()
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('loads queue item workflows without outputs', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [{ id: 'running', priority: 1 }],
|
||||
Pending: []
|
||||
} as never)
|
||||
vi.mocked(api.getJobDetail).mockResolvedValue({ id: 'running' } as never)
|
||||
vi.mocked(extractWorkflow).mockResolvedValue({ nodes: [] } as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.queue.show()
|
||||
await click(buttonByText(ui.queue.element, 'Load'))
|
||||
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith({ nodes: [] }, true, false)
|
||||
expect(app.nodeOutputs).toBeUndefined()
|
||||
})
|
||||
|
||||
it('loads history in reverse order', async () => {
|
||||
vi.mocked(api.getHistory).mockResolvedValue([
|
||||
{ id: 'old', priority: 1 },
|
||||
{ id: 'new', priority: 2 }
|
||||
] as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await ui.history.show()
|
||||
|
||||
expect(ui.history.element.textContent).toContain('2: LoadDelete')
|
||||
expect(ui.history.element.textContent?.indexOf('2:')).toBeLessThan(
|
||||
ui.history.element.textContent?.indexOf('1:') ?? Number.MAX_SAFE_INTEGER
|
||||
)
|
||||
})
|
||||
|
||||
it('updates queue status and auto-queues when enabled', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'instant'
|
||||
ui.lastQueueSize = 1
|
||||
ui.batchCount = 3
|
||||
|
||||
ui.setStatus({ exec_info: { queue_remaining: 0 } })
|
||||
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 3)
|
||||
expect(ui.queueSize.textContent).toBe('Queue size: 0')
|
||||
expect(ui.lastQueueSize).toBe(3)
|
||||
|
||||
ui.setStatus(null)
|
||||
expect(ui.queueSize.textContent).toBe('Queue size: ERR')
|
||||
})
|
||||
|
||||
it('does not auto-queue while a prior execution error is present', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'instant'
|
||||
ui.lastQueueSize = 1
|
||||
Object.assign(app, { lastExecutionError: new Error('failed') })
|
||||
|
||||
ui.setStatus({ exec_info: { queue_remaining: 0 } })
|
||||
|
||||
expect(app.queuePrompt).not.toHaveBeenCalled()
|
||||
expect(ui.lastQueueSize).toBe(0)
|
||||
})
|
||||
|
||||
it('tracks graph changes for change-mode auto queueing', () => {
|
||||
const ui = new ComfyUI(app)
|
||||
const graphChanged = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find(([eventName]) => eventName === 'graphChanged')?.[1]
|
||||
if (!graphChanged) throw new Error('Missing graphChanged listener')
|
||||
|
||||
ui.autoQueueEnabled = true
|
||||
ui.autoQueueMode = 'change'
|
||||
ui.lastQueueSize = 1
|
||||
graphChanged(fromAny(new CustomEvent('graphChanged')))
|
||||
|
||||
expect(ui.graphHasChanged).toBe(true)
|
||||
|
||||
ui.lastQueueSize = 0
|
||||
graphChanged(fromAny(new CustomEvent('graphChanged')))
|
||||
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
|
||||
expect(ui.graphHasChanged).toBe(false)
|
||||
})
|
||||
|
||||
it('wires primary menu buttons to app and command actions', async () => {
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'Queue Prompt'))
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(0, 1)
|
||||
|
||||
await click(buttonByText(document, 'Queue Front'))
|
||||
expect(app.queuePrompt).toHaveBeenCalledWith(-1, 1)
|
||||
|
||||
await click(buttonByText(document, 'Save'))
|
||||
await click(buttonByText(document, 'Save (API Format)'))
|
||||
await click(buttonByText(document, 'Refresh'))
|
||||
await click(buttonByText(document, 'Clipspace'))
|
||||
await click(buttonByText(document, 'Clear'))
|
||||
await click(buttonByText(document, 'Load Default'))
|
||||
await click(buttonByText(document, 'Reset View'))
|
||||
|
||||
expect(app.refreshComboInNodes).toHaveBeenCalled()
|
||||
expect(app.openClipspace).toHaveBeenCalled()
|
||||
expect(app.clean).toHaveBeenCalled()
|
||||
expect(app.loadGraphData).toHaveBeenCalledWith()
|
||||
expect(ui.menuContainer.style.display).toBe('none')
|
||||
})
|
||||
|
||||
it('wires file input and legacy option controls', async () => {
|
||||
const ui = new ComfyUI(app)
|
||||
const file = new File(['{}'], 'workflow.json', {
|
||||
type: 'application/json'
|
||||
})
|
||||
const fileInput = document.getElementById(
|
||||
'comfy-file-input'
|
||||
) as HTMLInputElement
|
||||
Object.defineProperty(fileInput, 'files', {
|
||||
configurable: true,
|
||||
value: [file]
|
||||
})
|
||||
|
||||
await Promise.resolve(
|
||||
fileInput.onchange?.call(fileInput, new Event('change'))
|
||||
)
|
||||
expect(app.handleFile).toHaveBeenCalledWith(file, 'file_button')
|
||||
expect(fileInput.value).toBe('')
|
||||
|
||||
const range = document.getElementById(
|
||||
'batchCountInputRange'
|
||||
) as HTMLInputElement
|
||||
const number = document.getElementById(
|
||||
'batchCountInputNumber'
|
||||
) as HTMLInputElement
|
||||
range.value = '4'
|
||||
const extraOptionsCheckbox = document.querySelector(
|
||||
'label input[type="checkbox"]'
|
||||
) as HTMLInputElement
|
||||
extraOptionsCheckbox.checked = true
|
||||
extraOptionsCheckbox.onchange?.call(
|
||||
extraOptionsCheckbox,
|
||||
fromPartial({ srcElement: extraOptionsCheckbox })
|
||||
)
|
||||
expect(ui.batchCount).toBe(4)
|
||||
expect(document.getElementById('extraOptions')?.style.display).toBe('block')
|
||||
|
||||
number.value = '7'
|
||||
number.oninput?.call(number, fromPartial({ target: number }))
|
||||
expect(range.value).toBe('7')
|
||||
|
||||
range.value = '9'
|
||||
range.oninput?.call(range, fromPartial({ srcElement: range }))
|
||||
expect(number.value).toBe('9')
|
||||
|
||||
const autoQueueCheckbox = document.getElementById(
|
||||
'autoQueueCheckbox'
|
||||
) as HTMLInputElement
|
||||
autoQueueCheckbox.checked = true
|
||||
autoQueueCheckbox.onchange?.call(
|
||||
autoQueueCheckbox,
|
||||
fromPartial({ target: autoQueueCheckbox })
|
||||
)
|
||||
expect(ui.autoQueueEnabled).toBe(true)
|
||||
|
||||
extraOptionsCheckbox.checked = false
|
||||
extraOptionsCheckbox.onchange?.call(
|
||||
extraOptionsCheckbox,
|
||||
fromPartial({ srcElement: extraOptionsCheckbox })
|
||||
)
|
||||
expect(ui.batchCount).toBe(1)
|
||||
expect(ui.autoQueueEnabled).toBe(false)
|
||||
expect(document.getElementById('extraOptions')?.style.display).toBe('none')
|
||||
})
|
||||
|
||||
it('toggles queue visibility through the menu button', async () => {
|
||||
vi.mocked(api.getQueue).mockResolvedValue({
|
||||
Running: [],
|
||||
Pending: []
|
||||
} as never)
|
||||
const ui = new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'View Queue'))
|
||||
expect(ui.queue.visible).toBe(true)
|
||||
|
||||
await click(buttonByText(document, 'Close'))
|
||||
expect(ui.queue.visible).toBe(false)
|
||||
})
|
||||
|
||||
it('does not clear or load defaults when confirmation is declined', async () => {
|
||||
mockSettingStore.get.mockReturnValue(true)
|
||||
vi.stubGlobal(
|
||||
'confirm',
|
||||
vi.fn(() => false)
|
||||
)
|
||||
try {
|
||||
new ComfyUI(app)
|
||||
|
||||
await click(buttonByText(document, 'Clear'))
|
||||
await click(buttonByText(document, 'Load Default'))
|
||||
|
||||
expect(app.clean).not.toHaveBeenCalled()
|
||||
expect(app.loadGraphData).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
vi.unstubAllGlobals()
|
||||
}
|
||||
})
|
||||
|
||||
it('persists manual menu dragging', () => {
|
||||
Object.defineProperty(document.body, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 1000
|
||||
})
|
||||
Object.defineProperty(document.body, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: 800
|
||||
})
|
||||
const ui = new ComfyUI(app)
|
||||
ui.menuContainer.style.display = 'block'
|
||||
Object.defineProperty(ui.menuContainer, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 100
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: 80
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'offsetLeft', {
|
||||
configurable: true,
|
||||
get: () => 700
|
||||
})
|
||||
Object.defineProperty(ui.menuContainer, 'offsetTop', {
|
||||
configurable: true,
|
||||
get: () => 20
|
||||
})
|
||||
const handle = ui.menuContainer.querySelector('.drag-handle') as HTMLElement
|
||||
|
||||
handle.onmousedown?.(
|
||||
new MouseEvent('mousedown', { clientX: 10, clientY: 10 })
|
||||
)
|
||||
document.onmousemove?.(
|
||||
new MouseEvent('mousemove', { clientX: 20, clientY: 30 })
|
||||
)
|
||||
document.onmouseup?.(new MouseEvent('mouseup'))
|
||||
|
||||
expect(ui.menuContainer.classList.contains('comfy-menu-manual-pos')).toBe(
|
||||
true
|
||||
)
|
||||
expect(ui.menuContainer.style.right).toBe('190px')
|
||||
expect(localStorage.getItem('Comfy.MenuPosition')).toBe(
|
||||
JSON.stringify({ x: 700, y: 20 })
|
||||
)
|
||||
expect(document.onmousemove).toBeNull()
|
||||
expect(document.onmouseup).toBeNull()
|
||||
})
|
||||
})
|
||||
199
src/scripts/ui/components/button.test.ts
Normal file
199
src/scripts/ui/components/button.test.ts
Normal file
@@ -0,0 +1,199 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyApp } from '@/scripts/app'
|
||||
|
||||
vi.mock('../../ui', () => ({
|
||||
$el: (tag: string, props?: Record<string, unknown>, children?: Node[]) => {
|
||||
const [tagName, ...classes] = tag.split('.')
|
||||
const element = document.createElement(tagName)
|
||||
if (classes.length) element.classList.add(...classes)
|
||||
if (props) {
|
||||
const listeners = Object.entries(props).filter(([key]) =>
|
||||
key.startsWith('on')
|
||||
)
|
||||
for (const [key, listener] of listeners) {
|
||||
if (typeof listener === 'function') {
|
||||
element.addEventListener(key.slice(2), listener as EventListener)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (children) element.append(...children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
prop: <T>(
|
||||
target: object,
|
||||
name: string,
|
||||
defaultValue: T,
|
||||
onChanged?: (currentValue: T, previousValue: T) => void
|
||||
) => {
|
||||
let currentValue: T
|
||||
Object.defineProperty(target, name, {
|
||||
get() {
|
||||
return currentValue
|
||||
},
|
||||
set(newValue: T) {
|
||||
const previousValue = currentValue
|
||||
currentValue = newValue
|
||||
onChanged?.(currentValue, previousValue)
|
||||
}
|
||||
})
|
||||
;(target as Record<string, T>)[name] = defaultValue
|
||||
return defaultValue
|
||||
}
|
||||
}))
|
||||
|
||||
import { ComfyButton } from './button'
|
||||
|
||||
class MockPopup extends EventTarget {
|
||||
element = document.createElement('div')
|
||||
open = false
|
||||
toggle = vi.fn(() => {
|
||||
this.open = !this.open
|
||||
this.dispatchEvent(new CustomEvent('change'))
|
||||
})
|
||||
}
|
||||
|
||||
function mockApp(settingValue: boolean) {
|
||||
let listener: (() => void) | undefined
|
||||
const app = {
|
||||
ui: {
|
||||
settings: {
|
||||
getSettingValue: vi.fn(() => settingValue),
|
||||
addEventListener: vi.fn((_event: string, callback: () => void) => {
|
||||
listener = callback
|
||||
})
|
||||
}
|
||||
}
|
||||
} as unknown as ComfyApp
|
||||
return {
|
||||
app,
|
||||
setSettingValue(value: boolean) {
|
||||
settingValue = value
|
||||
listener?.()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
describe('ComfyButton', () => {
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('renders icon, content, tooltip, enabled state, and click action', () => {
|
||||
const action = vi.fn()
|
||||
const button = new ComfyButton({
|
||||
icon: 'play',
|
||||
overIcon: 'pause',
|
||||
iconSize: 18,
|
||||
content: 'Run',
|
||||
tooltip: 'Queue prompt',
|
||||
enabled: false,
|
||||
classList: { primary: true, hiddenClass: false },
|
||||
action
|
||||
})
|
||||
|
||||
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
|
||||
expect(button.contentElement.textContent).toBe('Run')
|
||||
expect(button.element.title).toBe('Queue prompt')
|
||||
expect(button.element.getAttribute('aria-label')).toBe('Queue prompt')
|
||||
expect(button.element.classList.contains('primary')).toBe(true)
|
||||
expect(button.element.classList.contains('disabled')).toBe(true)
|
||||
expect((button.element as HTMLButtonElement).disabled).toBe(true)
|
||||
|
||||
button.enabled = true
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
expect(button.iconElement.className).toBe('mdi mdi-pause mdi-18px')
|
||||
button.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(button.iconElement.className).toBe('mdi mdi-play mdi-18px')
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
expect(action).toHaveBeenCalledWith(expect.any(MouseEvent), button)
|
||||
})
|
||||
|
||||
it('supports HTMLElement content and removing tooltip text', () => {
|
||||
const button = new ComfyButton({ content: 'Text', tooltip: 'Hint' })
|
||||
const content = document.createElement('strong')
|
||||
content.textContent = 'Element'
|
||||
|
||||
button.content = content
|
||||
button.tooltip = ''
|
||||
|
||||
expect(button.contentElement.firstElementChild).toBe(content)
|
||||
expect(button.element.hasAttribute('title')).toBe(false)
|
||||
})
|
||||
|
||||
it('updates the hover icon when overIcon changes while hovered', () => {
|
||||
const button = new ComfyButton({ icon: 'play' })
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.overIcon = 'pause'
|
||||
|
||||
expect(button.iconElement.className).toBe('mdi mdi-pause')
|
||||
})
|
||||
|
||||
it('hides and shows from a visibility setting', () => {
|
||||
const settings = mockApp(false)
|
||||
const button = new ComfyButton({
|
||||
app: settings.app,
|
||||
visibilitySetting: {
|
||||
id: 'Comfy.UseNewMenu',
|
||||
showValue: true
|
||||
}
|
||||
})
|
||||
|
||||
expect(button.hidden).toBe(true)
|
||||
expect(button.element.classList.contains('hidden')).toBe(true)
|
||||
|
||||
settings.setSettingValue(true)
|
||||
|
||||
expect(button.hidden).toBe(false)
|
||||
expect(button.element.classList.contains('hidden')).toBe(false)
|
||||
})
|
||||
|
||||
it('toggles click popups and reflects popup open state in classes', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(fromAny(popup))
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
|
||||
expect(popup.toggle).toHaveBeenCalledOnce()
|
||||
expect(button.element.classList.contains('popup-open')).toBe(true)
|
||||
|
||||
popup.toggle()
|
||||
|
||||
expect(button.element.classList.contains('popup-closed')).toBe(true)
|
||||
})
|
||||
|
||||
it('opens hover popups while either the button or popup is hovered', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(
|
||||
fromAny(popup),
|
||||
'hover'
|
||||
)
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
expect(popup.open).toBe(true)
|
||||
popup.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(popup.open).toBe(true)
|
||||
popup.element.dispatchEvent(new MouseEvent('mouseleave'))
|
||||
expect(popup.open).toBe(false)
|
||||
})
|
||||
|
||||
it('does not click-toggle a hover popup while hovered', () => {
|
||||
const popup = new MockPopup()
|
||||
const button = new ComfyButton({ icon: 'dots' }).withPopup(
|
||||
fromAny(popup),
|
||||
'hover'
|
||||
)
|
||||
|
||||
button.element.dispatchEvent(new MouseEvent('mouseenter'))
|
||||
button.element.dispatchEvent(new MouseEvent('click'))
|
||||
|
||||
expect(popup.toggle).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
284
src/scripts/ui/components/popup.test.ts
Normal file
284
src/scripts/ui/components/popup.test.ts
Normal file
@@ -0,0 +1,284 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
type ElChild = Node | string
|
||||
type ElInput = Record<string, unknown> | ElChild | ElChild[]
|
||||
|
||||
function appendChildren(element: HTMLElement, children: ElChild | ElChild[]) {
|
||||
const list = Array.isArray(children) ? children : [children]
|
||||
for (const child of list) {
|
||||
element.append(child)
|
||||
}
|
||||
}
|
||||
|
||||
vi.mock('../../ui', () => ({
|
||||
$el: (tag: string, propsOrChildren?: ElInput, children?: ElChild[]) => {
|
||||
const [tagName, ...classes] = tag.split('.')
|
||||
const element = document.createElement(tagName)
|
||||
if (classes.length) element.classList.add(...classes)
|
||||
|
||||
if (
|
||||
propsOrChildren instanceof Node ||
|
||||
typeof propsOrChildren === 'string' ||
|
||||
Array.isArray(propsOrChildren)
|
||||
) {
|
||||
appendChildren(element, propsOrChildren)
|
||||
} else if (propsOrChildren) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
if (key === '$' && typeof value === 'function') {
|
||||
value(element)
|
||||
} else if (key === 'parent' && value instanceof HTMLElement) {
|
||||
value.append(element)
|
||||
} else if (key === 'textContent') {
|
||||
element.textContent = String(value)
|
||||
} else if (key === 'ariaLabel') {
|
||||
element.setAttribute('aria-label', String(value))
|
||||
} else if (key === 'ariaHasPopup') {
|
||||
element.setAttribute('aria-haspopup', String(value))
|
||||
} else if (key === 'type') {
|
||||
element.setAttribute('type', String(value))
|
||||
} else if (
|
||||
key.toLowerCase().startsWith('on') &&
|
||||
typeof value === 'function'
|
||||
) {
|
||||
element.addEventListener(
|
||||
key.slice(2).toLowerCase(),
|
||||
value as EventListener
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (children) appendChildren(element, children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
prop: <T>(
|
||||
target: object,
|
||||
name: string,
|
||||
defaultValue: T,
|
||||
onChanged?: (currentValue: T, previousValue: T) => void
|
||||
) => {
|
||||
let currentValue: T
|
||||
Object.defineProperty(target, name, {
|
||||
get() {
|
||||
return currentValue
|
||||
},
|
||||
set(newValue: T) {
|
||||
const previousValue = currentValue
|
||||
currentValue = newValue
|
||||
onChanged?.(currentValue, previousValue)
|
||||
}
|
||||
})
|
||||
;(target as Record<string, T>)[name] = defaultValue
|
||||
return defaultValue
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../utils', () => ({
|
||||
applyClasses: (
|
||||
element: HTMLElement,
|
||||
classList: string | Record<string, boolean>,
|
||||
...baseClasses: string[]
|
||||
) => {
|
||||
element.className = baseClasses.join(' ')
|
||||
if (typeof classList === 'string') {
|
||||
element.classList.add(...classList.split(' ').filter(Boolean))
|
||||
} else {
|
||||
for (const [className, enabled] of Object.entries(classList)) {
|
||||
element.classList.toggle(className, enabled)
|
||||
}
|
||||
}
|
||||
},
|
||||
toggleElement:
|
||||
<T>(
|
||||
element: HTMLElement,
|
||||
{
|
||||
onShow
|
||||
}: {
|
||||
onShow?: (element: HTMLElement, value: T) => void
|
||||
} = {}
|
||||
) =>
|
||||
(value: T) => {
|
||||
element.hidden = !value
|
||||
if (value) onShow?.(element, value)
|
||||
}
|
||||
}))
|
||||
|
||||
import { ComfyAsyncDialog } from './asyncDialog'
|
||||
import { ComfyButton } from './button'
|
||||
import { ComfyButtonGroup } from './buttonGroup'
|
||||
import { ComfyPopup } from './popup'
|
||||
import { ComfySplitButton } from './splitButton'
|
||||
|
||||
function targetWithRect(rect: DOMRect) {
|
||||
const target = document.createElement('button')
|
||||
vi.spyOn(target, 'getBoundingClientRect').mockReturnValue(rect)
|
||||
document.body.append(target)
|
||||
return target
|
||||
}
|
||||
|
||||
describe('ComfyPopup and related UI components', () => {
|
||||
beforeEach(() => {
|
||||
document.body.replaceChildren()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('opens, positions, updates children and classes, and closes from escape', () => {
|
||||
const target = targetWithRect(
|
||||
DOMRect.fromRect({ x: 10, y: 20, width: 80, height: 30 })
|
||||
)
|
||||
const child = document.createElement('span')
|
||||
const popup = new ComfyPopup(
|
||||
{ target, classList: { menu: true, hidden: false } },
|
||||
child
|
||||
)
|
||||
const open = vi.fn()
|
||||
const close = vi.fn()
|
||||
const change = vi.fn()
|
||||
popup.addEventListener('open', open)
|
||||
popup.addEventListener('close', close)
|
||||
popup.addEventListener('change', change)
|
||||
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
|
||||
DOMRect.fromRect({ height: 20 })
|
||||
)
|
||||
|
||||
popup.open = true
|
||||
|
||||
expect(open).toHaveBeenCalledOnce()
|
||||
expect(change).toHaveBeenCalledOnce()
|
||||
expect(popup.element).toHaveClass('open')
|
||||
expect(popup.element.style.getPropertyValue('--left')).toBe('10px')
|
||||
expect(popup.element.style.getPropertyValue('--bottom')).toBe('35px')
|
||||
|
||||
const nextChild = document.createElement('strong')
|
||||
popup.children = [nextChild]
|
||||
popup.classList = 'extra'
|
||||
|
||||
expect(popup.element.firstElementChild).toBe(nextChild)
|
||||
expect(popup.element).toHaveClass('comfyui-popup', 'left', 'extra')
|
||||
|
||||
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Enter' }))
|
||||
expect(popup.open).toBe(true)
|
||||
|
||||
const escape = new KeyboardEvent('keydown', {
|
||||
key: 'Escape',
|
||||
cancelable: true
|
||||
})
|
||||
window.dispatchEvent(escape)
|
||||
|
||||
expect(close).toHaveBeenCalledOnce()
|
||||
expect(escape.defaultPrevented).toBe(true)
|
||||
expect(popup.open).toBe(false)
|
||||
})
|
||||
|
||||
it('handles outside clicks, target clicks, and relative right positioning', () => {
|
||||
const target = targetWithRect(
|
||||
DOMRect.fromRect({ x: 100, y: 40, width: 60, height: 25 })
|
||||
)
|
||||
const container = document.createElement('section')
|
||||
document.body.append(container)
|
||||
const popup = new ComfyPopup({
|
||||
target,
|
||||
container,
|
||||
position: 'relative',
|
||||
horizontal: 'right'
|
||||
})
|
||||
vi.spyOn(popup.element, 'getBoundingClientRect').mockReturnValue(
|
||||
DOMRect.fromRect({ height: 100 })
|
||||
)
|
||||
Object.defineProperty(popup.element, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 40
|
||||
})
|
||||
|
||||
popup.open = true
|
||||
target.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
expect(popup.open).toBe(true)
|
||||
|
||||
const outside = document.createElement('div')
|
||||
document.body.append(outside)
|
||||
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
|
||||
expect(popup.open).toBe(false)
|
||||
expect(popup.element.style.getPropertyValue('--left')).toBe('0px')
|
||||
expect(popup.element.style.getPropertyValue('--top')).toBe('25px')
|
||||
})
|
||||
|
||||
it('keeps outside clicks open when target clicks are not ignored', () => {
|
||||
const target = targetWithRect(DOMRect.fromRect({ height: 10 }))
|
||||
const popup = new ComfyPopup({
|
||||
target,
|
||||
ignoreTarget: false,
|
||||
closeOnEscape: false
|
||||
})
|
||||
const outside = document.createElement('div')
|
||||
document.body.append(outside)
|
||||
|
||||
popup.toggle()
|
||||
outside.dispatchEvent(new MouseEvent('click', { bubbles: true }))
|
||||
window.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape' }))
|
||||
|
||||
expect(popup.open).toBe(true)
|
||||
})
|
||||
|
||||
it('renders split button popup items and updates button groups', () => {
|
||||
const primary = new ComfyButton({ content: 'Queue' })
|
||||
const itemButton = new ComfyButton({ content: 'Queue front' })
|
||||
const rawItem = document.createElement('button')
|
||||
rawItem.textContent = 'Queue back'
|
||||
const split = new ComfySplitButton(
|
||||
{
|
||||
primary,
|
||||
mode: 'hover',
|
||||
horizontal: 'right',
|
||||
position: 'absolute'
|
||||
},
|
||||
itemButton,
|
||||
rawItem
|
||||
)
|
||||
|
||||
expect(split.element).toHaveClass('comfyui-split-button', 'hover')
|
||||
expect(split.popup.element).toHaveTextContent('Queue front')
|
||||
expect(split.popup.element).toHaveTextContent('Queue back')
|
||||
expect(document.body).toContainElement(split.popup.element)
|
||||
|
||||
const group = new ComfyButtonGroup(primary)
|
||||
group.append(itemButton)
|
||||
group.insert(fromAny(rawItem), 1)
|
||||
|
||||
expect(group.element.children).toHaveLength(3)
|
||||
expect(group.remove(itemButton)).toEqual([itemButton])
|
||||
expect(group.remove(itemButton)).toBeUndefined()
|
||||
expect(group.element.children).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('resolves async dialogs from buttons, close events, and prompt actions', async () => {
|
||||
const dialog = new ComfyAsyncDialog<number>([
|
||||
{ text: 'Seven', value: 7 },
|
||||
'Fallback'
|
||||
])
|
||||
|
||||
const promise = dialog.show('Pick one')
|
||||
dialog.element.querySelector<HTMLButtonElement>('button')?.click()
|
||||
await expect(promise).resolves.toBe(7)
|
||||
|
||||
const closePromise = dialog.show(document.createElement('em'))
|
||||
dialog.element.dispatchEvent(new Event('close'))
|
||||
await expect(closePromise).resolves.toBeNull()
|
||||
|
||||
const promptPromise = ComfyAsyncDialog.prompt({
|
||||
title: 'Confirm',
|
||||
message: 'Continue?',
|
||||
actions: [{ text: 'Yes', value: 'yes' }]
|
||||
})
|
||||
const prompt = Array.from(document.querySelectorAll('dialog')).at(-1)
|
||||
expect(prompt).toHaveTextContent('Confirm')
|
||||
prompt?.querySelector<HTMLButtonElement>('button')?.click()
|
||||
|
||||
await expect(promptPromise).resolves.toBe('yes')
|
||||
expect(document.body).not.toContainElement(prompt ?? null)
|
||||
})
|
||||
})
|
||||
69
src/scripts/ui/dialog.test.ts
Normal file
69
src/scripts/ui/dialog.test.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { ComfyDialog } from './dialog'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
$el: (
|
||||
selector: string,
|
||||
propsOrChildren?: Record<string, unknown> | Node[],
|
||||
maybeChildren?: Node[]
|
||||
) => {
|
||||
const [tag, ...classes] = selector.split('.')
|
||||
const element = document.createElement(tag || 'div')
|
||||
element.classList.add(...classes.filter(Boolean))
|
||||
const children = Array.isArray(propsOrChildren)
|
||||
? propsOrChildren
|
||||
: maybeChildren
|
||||
|
||||
if (propsOrChildren && !Array.isArray(propsOrChildren)) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
if (key === 'parent' && value instanceof Node) {
|
||||
value.appendChild(element)
|
||||
} else if (key === '$' && typeof value === 'function') {
|
||||
value(element)
|
||||
} else {
|
||||
Reflect.set(element, key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
element.append(...(children ?? []))
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
describe('ComfyDialog', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('shows string and element content and closes through the default button', () => {
|
||||
const dialog = new ComfyDialog()
|
||||
|
||||
dialog.show('<strong>Hello</strong>')
|
||||
|
||||
expect(dialog.element.style.display).toBe('flex')
|
||||
expect(dialog.textElement.innerHTML).toBe('<strong>Hello</strong>')
|
||||
|
||||
dialog.element.querySelector('button')?.click()
|
||||
expect(dialog.element.style.display).toBe('none')
|
||||
|
||||
const first = document.createElement('span')
|
||||
const second = document.createElement('em')
|
||||
dialog.show([first, second])
|
||||
|
||||
expect([...dialog.textElement.children]).toEqual([first, second])
|
||||
})
|
||||
|
||||
it('uses supplied custom buttons', () => {
|
||||
const button = document.createElement('button')
|
||||
const dialog = new ComfyDialog('section', [button])
|
||||
|
||||
expect(dialog.element.tagName).toBe('SECTION')
|
||||
expect(dialog.element.querySelector('button')).toBe(button)
|
||||
})
|
||||
})
|
||||
216
src/scripts/ui/draggableList.test.ts
Normal file
216
src/scripts/ui/draggableList.test.ts
Normal file
@@ -0,0 +1,216 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DraggableList } from './draggableList'
|
||||
|
||||
function createList(itemCount: number) {
|
||||
const container = document.createElement('div')
|
||||
const items = Array.from({ length: itemCount }, (_, index) => {
|
||||
const item = document.createElement('div')
|
||||
item.className = 'item'
|
||||
item.dataset.index = String(index)
|
||||
|
||||
const handle = document.createElement('button')
|
||||
handle.className = 'drag-handle'
|
||||
item.append(handle)
|
||||
container.append(item)
|
||||
|
||||
return item
|
||||
})
|
||||
document.body.append(container)
|
||||
return { container, items }
|
||||
}
|
||||
|
||||
function setRect(element: Element, top: number, height = 20) {
|
||||
return vi
|
||||
.spyOn(element, 'getBoundingClientRect')
|
||||
.mockReturnValue(new DOMRect(0, top, 100, height))
|
||||
}
|
||||
|
||||
function defineScrollMetrics(
|
||||
container: HTMLElement,
|
||||
scrollHeight: number,
|
||||
clientHeight: number
|
||||
) {
|
||||
Object.defineProperty(container, 'scrollHeight', {
|
||||
configurable: true,
|
||||
value: scrollHeight
|
||||
})
|
||||
Object.defineProperty(container, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: clientHeight
|
||||
})
|
||||
}
|
||||
|
||||
function mouseDragEvent(
|
||||
target: Element,
|
||||
overrides: Partial<MouseEvent> = {}
|
||||
): MouseEvent {
|
||||
return {
|
||||
button: 0,
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target,
|
||||
...overrides
|
||||
} satisfies Partial<MouseEvent> as MouseEvent
|
||||
}
|
||||
|
||||
describe('DraggableList', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
vi.restoreAllMocks()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('ignores missing containers and non-primary drag starts', () => {
|
||||
const listWithoutContainer = new DraggableList(null, '.item')
|
||||
const { container, items } = createList(1)
|
||||
const list = new DraggableList(container, '.item')
|
||||
|
||||
list.dragStart(
|
||||
mouseDragEvent(items[0].querySelector('.drag-handle')!, { button: 1 })
|
||||
)
|
||||
list.dragEnd()
|
||||
|
||||
expect(listWithoutContainer.listContainer).toBeNull()
|
||||
expect(list.draggableItem).toBeUndefined()
|
||||
})
|
||||
|
||||
it('starts from a handle, scrolls downward, and reorders upward', () => {
|
||||
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
const { container, items } = createList(3)
|
||||
const scrollBy = vi.fn((_left: number, top: number) => {
|
||||
container.scrollTop += top
|
||||
})
|
||||
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
|
||||
defineScrollMetrics(container, 120, 80)
|
||||
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
|
||||
new DOMRect(0, 0, 100, 80)
|
||||
)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[1], 30)
|
||||
setRect(items[2], 60)
|
||||
|
||||
const list = new DraggableList(container, '.item')
|
||||
const dragStart = vi.fn()
|
||||
const dragEnd = vi.fn()
|
||||
list.addEventListener('dragstart', dragStart)
|
||||
list.addEventListener('dragend', dragEnd)
|
||||
|
||||
list.dragStart(
|
||||
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
|
||||
clientX: 10,
|
||||
clientY: 70
|
||||
})
|
||||
)
|
||||
list.drag(
|
||||
mouseDragEvent(items[2].querySelector('.drag-handle')!, {
|
||||
clientX: 20,
|
||||
clientY: 100
|
||||
})
|
||||
)
|
||||
items[1].dataset.isToggled = ''
|
||||
list.dragEnd()
|
||||
|
||||
expect(dragStart).toHaveBeenCalledOnce()
|
||||
expect(dragEnd).toHaveBeenCalledOnce()
|
||||
expect(scrollBy).toHaveBeenCalledWith(0, 10)
|
||||
expect([...container.children]).toEqual([items[0], items[2], items[1]])
|
||||
expect(items[0].classList.contains('is-idle')).toBe(true)
|
||||
expect(items[1].classList.contains('is-idle')).toBe(true)
|
||||
})
|
||||
|
||||
it('supports touch coordinates, upward scrolling, and downward reorder', () => {
|
||||
vi.stubGlobal('requestAnimationFrame', (callback: FrameRequestCallback) => {
|
||||
callback(0)
|
||||
return 1
|
||||
})
|
||||
const { container, items } = createList(3)
|
||||
const scrollBy = vi.fn((_left: number, top: number) => {
|
||||
container.scrollTop += top
|
||||
})
|
||||
container.scrollTop = 10
|
||||
container.scrollBy = scrollBy as unknown as typeof container.scrollBy
|
||||
defineScrollMetrics(container, 120, 80)
|
||||
vi.spyOn(container, 'getBoundingClientRect').mockReturnValue(
|
||||
new DOMRect(0, 20, 100, 80)
|
||||
)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[1], 30)
|
||||
setRect(items[2], 60)
|
||||
|
||||
const list = new DraggableList(container, '.item')
|
||||
const touchStart = {
|
||||
button: 0,
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target: items[0].querySelector('.drag-handle')!,
|
||||
touches: [{ clientX: 5, clientY: 30 }]
|
||||
} as unknown as TouchEvent
|
||||
const touchMove = {
|
||||
clientX: 0,
|
||||
clientY: 0,
|
||||
preventDefault: vi.fn(),
|
||||
target: items[0].querySelector('.drag-handle')!,
|
||||
touches: [{ clientX: 8, clientY: 10 }]
|
||||
} as unknown as TouchEvent
|
||||
|
||||
list.dragStart(touchStart)
|
||||
list.drag(touchMove)
|
||||
items[1].dataset.isToggled = ''
|
||||
list.dragEnd()
|
||||
|
||||
expect(scrollBy).toHaveBeenCalledWith(0, -10)
|
||||
expect([...container.children]).toEqual([items[1], items[0], items[2]])
|
||||
})
|
||||
|
||||
it('updates idle item state around the dragged item midpoint', () => {
|
||||
const { container, items } = createList(3)
|
||||
const list = new DraggableList(container, '.item')
|
||||
const state = list as unknown as {
|
||||
items: HTMLElement[]
|
||||
draggableItem: HTMLElement
|
||||
}
|
||||
state.items = items
|
||||
state.draggableItem = items[1]
|
||||
list.itemsGap = 5
|
||||
items[0].classList.add('is-idle')
|
||||
items[1].classList.add('is-idle')
|
||||
items[2].classList.add('is-idle')
|
||||
items[0].dataset.isAbove = ''
|
||||
const draggedRect = setRect(items[1], -10)
|
||||
setRect(items[0], 0)
|
||||
setRect(items[2], 60)
|
||||
|
||||
list.updateIdleItemsStateAndPosition()
|
||||
|
||||
expect(items[0].dataset.isToggled).toBe('')
|
||||
expect(items[0].style.transform).toBe('translateY(25px)')
|
||||
expect(items[2].style.transform).toBe('')
|
||||
|
||||
draggedRect.mockReturnValue(new DOMRect(0, 100, 100, 20))
|
||||
list.updateIdleItemsStateAndPosition()
|
||||
|
||||
expect(items[0].dataset.isToggled).toBeUndefined()
|
||||
expect(items[2].dataset.isToggled).toBe('')
|
||||
expect(items[2].style.transform).toBe('translateY(-25px)')
|
||||
})
|
||||
|
||||
it('uses zero gap for short lists and disposes listeners', () => {
|
||||
const { container } = createList(1)
|
||||
const list = new DraggableList(container, '.item')
|
||||
const off = vi.fn()
|
||||
const disposableList = list as unknown as { off: Array<() => void> }
|
||||
disposableList.off = [off]
|
||||
|
||||
list.setItemsGap()
|
||||
list.dispose()
|
||||
|
||||
expect(list.itemsGap).toBe(0)
|
||||
expect(off).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
33
src/scripts/ui/imagePreview.test.ts
Normal file
33
src/scripts/ui/imagePreview.test.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { calculateImageGrid } from './imagePreview'
|
||||
|
||||
function createImage(width: number, height: number) {
|
||||
const img = document.createElement('img')
|
||||
Object.defineProperty(img, 'naturalWidth', {
|
||||
configurable: true,
|
||||
value: width
|
||||
})
|
||||
Object.defineProperty(img, 'naturalHeight', {
|
||||
configurable: true,
|
||||
value: height
|
||||
})
|
||||
return img
|
||||
}
|
||||
|
||||
describe('imagePreview', () => {
|
||||
it('calculates the highest-area grid', () => {
|
||||
const images = [
|
||||
createImage(100, 100),
|
||||
createImage(100, 100),
|
||||
createImage(100, 100)
|
||||
]
|
||||
|
||||
expect(calculateImageGrid(images, 300, 120)).toMatchObject({
|
||||
cellWidth: 100,
|
||||
cellHeight: 100,
|
||||
cols: 3,
|
||||
rows: 1
|
||||
})
|
||||
})
|
||||
})
|
||||
86
src/scripts/ui/toggleSwitch.test.ts
Normal file
86
src/scripts/ui/toggleSwitch.test.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { toggleSwitch } from './toggleSwitch'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
$el: (
|
||||
selector: string,
|
||||
propsOrChildren?: Record<string, unknown> | Node[] | Node,
|
||||
maybeChildren?: Node[] | Node
|
||||
) => {
|
||||
const [tag, ...classes] = selector.split('.')
|
||||
const element = document.createElement(tag || 'div')
|
||||
element.classList.add(...classes.filter(Boolean))
|
||||
const children = Array.isArray(propsOrChildren)
|
||||
? propsOrChildren
|
||||
: propsOrChildren instanceof Node
|
||||
? [propsOrChildren]
|
||||
: Array.isArray(maybeChildren)
|
||||
? maybeChildren
|
||||
: maybeChildren instanceof Node
|
||||
? [maybeChildren]
|
||||
: []
|
||||
|
||||
if (
|
||||
propsOrChildren &&
|
||||
!(propsOrChildren instanceof Node) &&
|
||||
!Array.isArray(propsOrChildren)
|
||||
) {
|
||||
for (const [key, value] of Object.entries(propsOrChildren)) {
|
||||
Reflect.set(element, key, value)
|
||||
}
|
||||
}
|
||||
|
||||
element.append(...children)
|
||||
return element
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
describe('toggleSwitch', () => {
|
||||
afterEach(() => {
|
||||
document.body.replaceChildren()
|
||||
})
|
||||
|
||||
it('selects the first item when none is preselected', () => {
|
||||
const onChange = vi.fn()
|
||||
const container = toggleSwitch('mode', ['first', 'second'], { onChange })
|
||||
const labels = [...container.querySelectorAll('label')]
|
||||
const inputs = [...container.querySelectorAll('input')]
|
||||
|
||||
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(true)
|
||||
expect((inputs[0] as HTMLInputElement).checked).toBe(true)
|
||||
expect(onChange).toHaveBeenCalledWith({
|
||||
item: 'first',
|
||||
prev: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('moves selection and reports the previous item', () => {
|
||||
const onChange = vi.fn()
|
||||
const container = toggleSwitch(
|
||||
'mode',
|
||||
[
|
||||
{ text: 'first', tooltip: 'First option' },
|
||||
{ text: 'second', value: '2' }
|
||||
],
|
||||
{ onChange }
|
||||
)
|
||||
const labels = [...container.querySelectorAll('label')]
|
||||
const secondInput = labels[1].querySelector('input') as HTMLInputElement
|
||||
|
||||
secondInput.onchange?.(new Event('change'))
|
||||
|
||||
expect(labels[0].classList.contains('comfy-toggle-selected')).toBe(false)
|
||||
expect(labels[1].classList.contains('comfy-toggle-selected')).toBe(true)
|
||||
expect(labels[0].title).toBe('First option')
|
||||
expect(secondInput.value).toBe('2')
|
||||
expect(onChange).toHaveBeenLastCalledWith({
|
||||
item: { text: 'second', value: '2' },
|
||||
prev: { text: 'first', tooltip: 'First option', value: 'first' }
|
||||
})
|
||||
})
|
||||
})
|
||||
45
src/scripts/ui/utils.test.ts
Normal file
45
src/scripts/ui/utils.test.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { applyClasses, toggleElement } from './utils'
|
||||
|
||||
describe('ui utils', () => {
|
||||
it('applies string, array, object, and required classes', () => {
|
||||
const element = document.createElement('div')
|
||||
|
||||
applyClasses(element, 'one two', 'required')
|
||||
expect([...element.classList]).toEqual(['one', 'two', 'required'])
|
||||
|
||||
applyClasses(element, ['three', 'four'])
|
||||
expect([...element.classList]).toEqual(['three', 'four'])
|
||||
|
||||
applyClasses(element, { five: true, six: false, seven: true })
|
||||
expect([...element.classList]).toEqual(['five', 'seven'])
|
||||
|
||||
applyClasses(element, null as unknown as string)
|
||||
expect(element.className).toBe('')
|
||||
})
|
||||
|
||||
it('toggles an element through a placeholder', () => {
|
||||
const parent = document.createElement('div')
|
||||
const element = document.createElement('span')
|
||||
const onHide = vi.fn()
|
||||
const onShow = vi.fn()
|
||||
parent.append(element)
|
||||
const toggle = toggleElement(element, { onHide, onShow })
|
||||
|
||||
toggle(false)
|
||||
expect(parent.firstChild).toBeInstanceOf(Comment)
|
||||
expect(onHide).toHaveBeenCalledWith(element)
|
||||
|
||||
toggle(true)
|
||||
expect(parent.firstChild).toBe(element)
|
||||
expect(onShow).toHaveBeenCalledWith(element, true)
|
||||
|
||||
toggle('visible')
|
||||
expect(onShow).toHaveBeenCalledWith(element, 'visible')
|
||||
|
||||
toggle(false)
|
||||
expect(parent.firstChild).toBeInstanceOf(Comment)
|
||||
expect(onHide).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
127
src/scripts/utils.test.ts
Normal file
127
src/scripts/utils.test.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
addStylesheet,
|
||||
clone,
|
||||
getStorageValue,
|
||||
prop,
|
||||
setStorageValue
|
||||
} from './utils'
|
||||
|
||||
interface LinkAttrs {
|
||||
href: string
|
||||
onerror: (error: Event) => void
|
||||
onload: () => void
|
||||
parent: HTMLElement
|
||||
rel: string
|
||||
type: string
|
||||
}
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
api: {
|
||||
clientId: null as string | null,
|
||||
initialClientId: null as string | null
|
||||
},
|
||||
$el: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('./api', () => ({
|
||||
api: mocks.api
|
||||
}))
|
||||
|
||||
vi.mock('./ui', () => ({
|
||||
$el: mocks.$el
|
||||
}))
|
||||
|
||||
function lastLinkAttrs() {
|
||||
return mocks.$el.mock.calls.at(-1)?.[1] as LinkAttrs
|
||||
}
|
||||
|
||||
describe('scripts utils', () => {
|
||||
afterEach(() => {
|
||||
localStorage.clear()
|
||||
sessionStorage.clear()
|
||||
mocks.api.clientId = null
|
||||
mocks.api.initialClientId = null
|
||||
mocks.$el.mockReset()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('clones with structuredClone and falls back to JSON cloning', () => {
|
||||
const source = { nested: { value: 1 } }
|
||||
|
||||
expect(clone(source)).toEqual(source)
|
||||
|
||||
vi.stubGlobal(
|
||||
'structuredClone',
|
||||
vi.fn(() => {
|
||||
throw new Error('unsupported')
|
||||
})
|
||||
)
|
||||
|
||||
const cloned = clone(source)
|
||||
cloned.nested.value = 2
|
||||
|
||||
expect(cloned).toEqual({ nested: { value: 2 } })
|
||||
expect(source).toEqual({ nested: { value: 1 } })
|
||||
})
|
||||
|
||||
it('adds stylesheets from script and relative URLs', async () => {
|
||||
const scriptPromise = addStylesheet('/extensions/example.js')
|
||||
lastLinkAttrs().onload()
|
||||
|
||||
await expect(scriptPromise).resolves.toBeUndefined()
|
||||
expect(lastLinkAttrs()).toMatchObject({
|
||||
href: '/extensions/example.css',
|
||||
parent: document.head,
|
||||
rel: 'stylesheet',
|
||||
type: 'text/css'
|
||||
})
|
||||
|
||||
const cssPromise = addStylesheet('theme.css', 'https://example.com/base/')
|
||||
lastLinkAttrs().onload()
|
||||
|
||||
await expect(cssPromise).resolves.toBeUndefined()
|
||||
expect(lastLinkAttrs().href).toBe('https://example.com/base/theme.css')
|
||||
})
|
||||
|
||||
it('rejects when stylesheet loading fails', async () => {
|
||||
const promise = addStylesheet('missing.css', 'https://example.com/')
|
||||
const error = new Event('error')
|
||||
lastLinkAttrs().onerror(error)
|
||||
|
||||
await expect(promise).rejects.toBe(error)
|
||||
})
|
||||
|
||||
it('defines an observable property with the supplied default', () => {
|
||||
const target = {}
|
||||
const onChanged = vi.fn()
|
||||
|
||||
expect(prop(target, 'mode', 'initial', onChanged)).toBe('initial')
|
||||
Object.assign(target, { mode: 'next' })
|
||||
|
||||
expect((target as { mode: string }).mode).toBe('next')
|
||||
expect(onChanged).toHaveBeenCalledWith('next', undefined, target, 'mode')
|
||||
})
|
||||
|
||||
it('uses client-scoped storage before local fallback', () => {
|
||||
mocks.api.clientId = 'client-1'
|
||||
setStorageValue('setting', 'client-value')
|
||||
sessionStorage.removeItem('setting:client-1')
|
||||
|
||||
expect(getStorageValue('setting')).toBe('client-value')
|
||||
expect(localStorage.getItem('setting')).toBe('client-value')
|
||||
|
||||
sessionStorage.setItem('setting:client-1', 'session-value')
|
||||
|
||||
expect(getStorageValue('setting')).toBe('session-value')
|
||||
})
|
||||
|
||||
it('uses initial client id when the current client id is unavailable', () => {
|
||||
mocks.api.initialClientId = 'initial-1'
|
||||
setStorageValue('setting', 'initial-value')
|
||||
|
||||
expect(sessionStorage.getItem('setting:initial-1')).toBe('initial-value')
|
||||
expect(getStorageValue('setting')).toBe('initial-value')
|
||||
})
|
||||
})
|
||||
356
src/scripts/widgets.test.ts
Normal file
356
src/scripts/widgets.test.ts
Normal file
@@ -0,0 +1,356 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type {
|
||||
IBaseWidget,
|
||||
IComboWidget
|
||||
} from '@/lib/litegraph/src/types/widgets'
|
||||
import type { InputSpec } from '@/schemas/nodeDefSchema'
|
||||
|
||||
const mockSettingGet = vi.hoisted(() => vi.fn())
|
||||
const mockNextValueForLinkedTarget = vi.hoisted(() => vi.fn())
|
||||
const mockIsComboWidget = vi.hoisted(() => vi.fn())
|
||||
const mockTransformInputSpecV1ToV2 = vi.hoisted(() => vi.fn())
|
||||
|
||||
function v2WidgetConstructor(kind: string) {
|
||||
return () => (_node: LGraphNode, inputSpec: { name: string }) => ({
|
||||
name: `${kind}:${inputSpec.name}`,
|
||||
options: { minNodeSize: [20, 30] }
|
||||
})
|
||||
}
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string) => `translated:${key}`
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
get: mockSettingGet
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/litegraph/src/litegraph', () => ({
|
||||
isComboWidget: mockIsComboWidget
|
||||
}))
|
||||
|
||||
vi.mock('./valueControl', () => ({
|
||||
nextValueForLinkedTarget: mockNextValueForLinkedTarget
|
||||
}))
|
||||
|
||||
vi.mock('@/schemas/nodeDef/migration', () => ({
|
||||
transformInputSpecV1ToV2: mockTransformInputSpecV1ToV2
|
||||
}))
|
||||
|
||||
vi.mock('@/core/graph/widgets/dynamicWidgets', () => ({
|
||||
dynamicWidgets: {
|
||||
DYNAMIC: () => ({
|
||||
widget: { name: 'dynamic', options: {} },
|
||||
minWidth: 1,
|
||||
minHeight: 1
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBooleanWidget',
|
||||
() => ({ useBooleanWidget: v2WidgetConstructor('BOOLEAN') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxWidget',
|
||||
() => ({ useBoundingBoxWidget: v2WidgetConstructor('BOUNDING_BOX') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useCurveWidget',
|
||||
() => ({ useCurveWidget: v2WidgetConstructor('CURVE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useChartWidget',
|
||||
() => ({ useChartWidget: v2WidgetConstructor('CHART') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useColorWidget',
|
||||
() => ({ useColorWidget: v2WidgetConstructor('COLOR') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useComboWidget',
|
||||
() => ({ useComboWidget: v2WidgetConstructor('COMBO') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useFloatWidget',
|
||||
() => ({ useFloatWidget: v2WidgetConstructor('FLOAT') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useGalleriaWidget',
|
||||
() => ({ useGalleriaWidget: v2WidgetConstructor('GALLERIA') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useBoundingBoxesWidget',
|
||||
() => ({ useBoundingBoxesWidget: v2WidgetConstructor('BOUNDING_BOXES') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useColorsWidget',
|
||||
() => ({ useColorsWidget: v2WidgetConstructor('COLORS') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useImageCompareWidget',
|
||||
() => ({ useImageCompareWidget: v2WidgetConstructor('IMAGECOMPARE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useImageUploadWidget',
|
||||
() => ({
|
||||
useImageUploadWidget: () => (_node: LGraphNode, inputName: string) => ({
|
||||
widget: { name: `IMAGEUPLOAD:${inputName}`, options: {} },
|
||||
minWidth: 5,
|
||||
minHeight: 6
|
||||
})
|
||||
})
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useIntWidget',
|
||||
() => ({ useIntWidget: v2WidgetConstructor('INT') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useMarkdownWidget',
|
||||
() => ({ useMarkdownWidget: v2WidgetConstructor('MARKDOWN') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/usePainterWidget',
|
||||
() => ({ usePainterWidget: v2WidgetConstructor('PAINTER') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useRangeWidget',
|
||||
() => ({ useRangeWidget: v2WidgetConstructor('RANGE') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useStringWidget',
|
||||
() => ({ useStringWidget: v2WidgetConstructor('STRING') })
|
||||
)
|
||||
vi.mock(
|
||||
'@/renderer/extensions/vueNodes/widgets/composables/useTextareaWidget',
|
||||
() => ({ useTextareaWidget: v2WidgetConstructor('TEXTAREA') })
|
||||
)
|
||||
|
||||
vi.mock('./domWidget', () => ({}))
|
||||
vi.mock('./errorNodeWidgets', () => ({}))
|
||||
|
||||
import {
|
||||
ComfyWidgets,
|
||||
IS_CONTROL_WIDGET,
|
||||
addValueControlWidget,
|
||||
addValueControlWidgets,
|
||||
isValidWidgetType,
|
||||
updateControlWidgetLabel
|
||||
} from './widgets'
|
||||
|
||||
// `linkedWidgets`, `beforeQueued`, and `afterQueued` already exist on
|
||||
// IBaseWidget (via the litegraph augmentation), so no extra members needed.
|
||||
type MockWidget = IBaseWidget
|
||||
|
||||
function makeTargetWidget(overrides: Partial<MockWidget> = {}): MockWidget {
|
||||
return {
|
||||
name: 'seed',
|
||||
value: 1,
|
||||
callback: vi.fn(),
|
||||
options: {},
|
||||
linkedWidgets: [],
|
||||
computedDisabled: false,
|
||||
...overrides
|
||||
} as MockWidget
|
||||
}
|
||||
|
||||
function makeNode(inputs: LGraphNode['inputs'] = []) {
|
||||
const widgets: MockWidget[] = []
|
||||
const node = {
|
||||
id: 42,
|
||||
inputs,
|
||||
addWidget: vi.fn(
|
||||
(
|
||||
type: string,
|
||||
name: string,
|
||||
value: string,
|
||||
callback: () => void,
|
||||
options: Record<string, unknown>
|
||||
) => {
|
||||
const widget: MockWidget = fromAny({
|
||||
type,
|
||||
name,
|
||||
value,
|
||||
callback,
|
||||
options,
|
||||
linkedWidgets: [],
|
||||
computedDisabled: false
|
||||
})
|
||||
widgets.push(widget)
|
||||
return widget
|
||||
}
|
||||
)
|
||||
}
|
||||
return { node: node as unknown as LGraphNode, widgets }
|
||||
}
|
||||
|
||||
describe('widgets', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockSettingGet.mockReturnValue('after')
|
||||
mockNextValueForLinkedTarget.mockReturnValue('next')
|
||||
mockIsComboWidget.mockImplementation(
|
||||
(widget: MockWidget) => widget.type === 'combo'
|
||||
)
|
||||
mockTransformInputSpecV1ToV2.mockImplementation(
|
||||
(_inputData: InputSpec, options: { name: string }) => ({
|
||||
name: options.name
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('updates the control widget label from the configured run mode', () => {
|
||||
const widget = makeTargetWidget()
|
||||
|
||||
mockSettingGet.mockReturnValue('before')
|
||||
updateControlWidgetLabel(widget)
|
||||
expect(widget.label).toBe('translated:g.control_before_generate')
|
||||
|
||||
mockSettingGet.mockReturnValue('after')
|
||||
updateControlWidgetLabel(widget)
|
||||
expect(widget.label).toBe('translated:g.control_after_generate')
|
||||
})
|
||||
|
||||
it('adds control and filter widgets for combo targets', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget({ type: 'combo', computedDisabled: true })
|
||||
|
||||
const result = addValueControlWidgets(node, target, '', undefined, [
|
||||
'COMBO',
|
||||
{
|
||||
control_prefix: 'custom'
|
||||
}
|
||||
] as unknown as InputSpec)
|
||||
|
||||
expect(result).toHaveLength(2)
|
||||
expect(widgets[0].name).toBe('custom control_after_generate')
|
||||
expect(widgets[0].value).toBe('randomize')
|
||||
expect((widgets[0] as IComboWidget).options.values).toContain(
|
||||
'increment-wrap'
|
||||
)
|
||||
expect(widgets[0][IS_CONTROL_WIDGET]).toBe(true)
|
||||
expect(widgets[0].disabled).toBe(true)
|
||||
expect(widgets[1].name).toBe('custom control_filter_list')
|
||||
expect(widgets[1].disabled).toBe(true)
|
||||
})
|
||||
|
||||
it('uses explicit option names and can skip the combo filter widget', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget({ type: 'combo' })
|
||||
|
||||
addValueControlWidgets(
|
||||
node,
|
||||
target,
|
||||
'fixed',
|
||||
{
|
||||
addFilterList: false,
|
||||
controlAfterGenerateName: 'mode'
|
||||
},
|
||||
['COMBO', {}] as unknown as InputSpec
|
||||
)
|
||||
|
||||
expect(widgets).toHaveLength(1)
|
||||
expect(widgets[0].name).toBe('mode')
|
||||
})
|
||||
|
||||
it('applies linked target values after queueing in after mode', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].afterQueued?.({ isPartialExecution: true })
|
||||
|
||||
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith({
|
||||
target,
|
||||
linkedWidgets: target.linkedWidgets,
|
||||
nodeId: 42,
|
||||
isPartialExecution: true
|
||||
})
|
||||
expect(target.value).toBe('next')
|
||||
expect(target.callback).toHaveBeenCalledWith('next')
|
||||
})
|
||||
|
||||
it('waits until the second beforeQueued call in before mode', () => {
|
||||
mockSettingGet.mockReturnValue('before')
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].beforeQueued?.()
|
||||
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
|
||||
|
||||
widgets[0].beforeQueued?.({ isPartialExecution: false })
|
||||
expect(mockNextValueForLinkedTarget).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ isPartialExecution: false })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not change the target when the target has a linked input or no next value', () => {
|
||||
const { node, widgets } = makeNode([
|
||||
{ widget: { name: 'seed' }, link: 1 }
|
||||
] as LGraphNode['inputs'])
|
||||
const target = makeTargetWidget()
|
||||
|
||||
addValueControlWidgets(node, target)
|
||||
widgets[0].afterQueued?.()
|
||||
expect(mockNextValueForLinkedTarget).not.toHaveBeenCalled()
|
||||
|
||||
const unlinked = makeNode()
|
||||
mockNextValueForLinkedTarget.mockReturnValue(undefined)
|
||||
addValueControlWidgets(unlinked.node, target)
|
||||
unlinked.widgets[0].afterQueued?.()
|
||||
expect(target.callback).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses the legacy single control widget name from input data before widgetName', () => {
|
||||
const { node, widgets } = makeNode()
|
||||
const target = makeTargetWidget()
|
||||
|
||||
const result = addValueControlWidget(
|
||||
node,
|
||||
target,
|
||||
'fixed',
|
||||
undefined,
|
||||
'fallback',
|
||||
[
|
||||
'INT',
|
||||
{
|
||||
control_after_generate: 'from_input_data'
|
||||
}
|
||||
] as unknown as InputSpec
|
||||
)
|
||||
|
||||
expect(result).toBe(widgets[0])
|
||||
expect(widgets[0].name).toBe('from_input_data')
|
||||
})
|
||||
|
||||
it('exposes transformed widget constructors and type validation', () => {
|
||||
const { node } = makeNode()
|
||||
|
||||
const intWidget = ComfyWidgets.INT(
|
||||
node,
|
||||
'value',
|
||||
['INT', {}] as unknown as InputSpec,
|
||||
{} as never
|
||||
)
|
||||
|
||||
expect(intWidget.widget.name).toBe('INT:value')
|
||||
expect(intWidget.minWidth).toBe(20)
|
||||
expect(intWidget.minHeight).toBe(30)
|
||||
expect(
|
||||
ComfyWidgets.IMAGEUPLOAD(node, 'image', ['IMAGE', {}], {} as never)
|
||||
).toMatchObject({
|
||||
widget: { name: 'IMAGEUPLOAD:image' },
|
||||
minWidth: 5,
|
||||
minHeight: 6
|
||||
})
|
||||
expect(isValidWidgetType('INT')).toBe(true)
|
||||
expect(isValidWidgetType('DYNAMIC')).toBe(true)
|
||||
expect(isValidWidgetType('missing')).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -38,6 +38,12 @@ describe('useAudioService', () => {
|
||||
name: 'test-audio-123.wav'
|
||||
}
|
||||
|
||||
async function freshService() {
|
||||
vi.resetModules()
|
||||
const audioServiceModule = await import('@/services/audioService')
|
||||
return audioServiceModule.useAudioService()
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
@@ -90,12 +96,41 @@ describe('useAudioService', () => {
|
||||
)
|
||||
mockRegister.mockRejectedValueOnce(error)
|
||||
|
||||
await service.registerWavEncoder()
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(mockConnect).toHaveBeenCalledTimes(0)
|
||||
expect(mockRegister).toHaveBeenCalledTimes(0)
|
||||
expect(mockConnect).toHaveBeenCalledTimes(1)
|
||||
expect(mockRegister).toHaveBeenCalledTimes(1)
|
||||
expect(console.error).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should log encoder registration errors', async () => {
|
||||
const error = new Error('Encoder failed')
|
||||
mockRegister.mockRejectedValueOnce(error)
|
||||
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Audio Service Error (encoder):',
|
||||
'Failed to register WAV encoder',
|
||||
error
|
||||
)
|
||||
})
|
||||
|
||||
it('should log non-Error encoder registration failures', async () => {
|
||||
mockRegister.mockRejectedValueOnce('Encoder failed')
|
||||
|
||||
const isolatedService = await freshService()
|
||||
await isolatedService.registerWavEncoder()
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
'Audio Service Error (encoder):',
|
||||
'Failed to register WAV encoder',
|
||||
'Encoder failed'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('stopAllTracks', () => {
|
||||
|
||||
118
src/services/autoQueueService.test.ts
Normal file
118
src/services/autoQueueService.test.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { setupAutoQueueHandler } from '@/services/autoQueueService'
|
||||
|
||||
type ApiEvent = 'graphChanged'
|
||||
type ApiListener = () => void
|
||||
type Subscription = () => Promise<void> | void
|
||||
|
||||
const {
|
||||
listeners,
|
||||
queueCountStore,
|
||||
queueSettingsStore,
|
||||
appState,
|
||||
addEventListener,
|
||||
isInstantRunningMode
|
||||
} = vi.hoisted(() => ({
|
||||
listeners: new Map<ApiEvent, ApiListener>(),
|
||||
queueCountStore: {
|
||||
count: 0,
|
||||
subscription: undefined as Subscription | undefined,
|
||||
$subscribe: vi.fn((_callback: Subscription) => {
|
||||
queueCountStore.subscription = _callback
|
||||
})
|
||||
},
|
||||
queueSettingsStore: {
|
||||
mode: 'manual',
|
||||
batchCount: 1
|
||||
},
|
||||
appState: {
|
||||
lastExecutionError: null as unknown,
|
||||
queuePrompt: vi.fn()
|
||||
},
|
||||
addEventListener: vi.fn((event: ApiEvent, listener: ApiListener) => {
|
||||
listeners.set(event, listener)
|
||||
}),
|
||||
isInstantRunningMode: vi.fn((mode: string) => mode === 'instant')
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: { addEventListener }
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: appState
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/queueStore', () => ({
|
||||
isInstantRunningMode,
|
||||
useQueuePendingTaskCountStore: () => queueCountStore,
|
||||
useQueueSettingsStore: () => queueSettingsStore
|
||||
}))
|
||||
|
||||
beforeEach(() => {
|
||||
listeners.clear()
|
||||
queueCountStore.count = 0
|
||||
queueCountStore.subscription = undefined
|
||||
queueCountStore.$subscribe.mockClear()
|
||||
queueSettingsStore.mode = 'manual'
|
||||
queueSettingsStore.batchCount = 1
|
||||
appState.lastExecutionError = null
|
||||
appState.queuePrompt.mockReset().mockResolvedValue(undefined)
|
||||
addEventListener.mockClear()
|
||||
isInstantRunningMode
|
||||
.mockClear()
|
||||
.mockImplementation((mode) => mode === 'instant')
|
||||
})
|
||||
|
||||
describe('setupAutoQueueHandler', () => {
|
||||
it('queues immediately on graph changes when change mode is idle', () => {
|
||||
queueSettingsStore.mode = 'change'
|
||||
queueSettingsStore.batchCount = 3
|
||||
|
||||
setupAutoQueueHandler()
|
||||
listeners.get('graphChanged')?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 3)
|
||||
})
|
||||
|
||||
it('queues after pending work drains in instant mode', async () => {
|
||||
queueSettingsStore.mode = 'instant'
|
||||
queueSettingsStore.batchCount = 2
|
||||
queueCountStore.count = 0
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledWith(0, 2)
|
||||
})
|
||||
|
||||
it('queues after a changed graph drains from an active queue', async () => {
|
||||
queueSettingsStore.mode = 'change'
|
||||
queueCountStore.count = 1
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
listeners.get('graphChanged')?.()
|
||||
expect(appState.queuePrompt).not.toHaveBeenCalled()
|
||||
|
||||
queueCountStore.count = 0
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('does not requeue while work remains or the last run failed', async () => {
|
||||
queueSettingsStore.mode = 'instant'
|
||||
queueCountStore.count = 1
|
||||
|
||||
setupAutoQueueHandler()
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
appState.lastExecutionError = { message: 'failed' }
|
||||
queueCountStore.count = 0
|
||||
await queueCountStore.subscription?.()
|
||||
|
||||
expect(appState.queuePrompt).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
363
src/services/colorPaletteService.test.ts
Normal file
363
src/services/colorPaletteService.test.ts
Normal file
@@ -0,0 +1,363 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DEFAULT_DARK_COLOR_PALETTE } from '@/constants/coreColorPalettes'
|
||||
import {
|
||||
LGraphCanvas,
|
||||
LiteGraph,
|
||||
RenderShape
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { CompletedPalette, Palette } from '@/schemas/colorPaletteSchema'
|
||||
|
||||
const mockCanvas = vi.hoisted(() => ({
|
||||
default_connection_color_byType: {} as Record<string, string>,
|
||||
node_title_color: '',
|
||||
default_link_color: '',
|
||||
background_image: '',
|
||||
clear_background_color: '',
|
||||
_pattern: 'pattern' as string | undefined,
|
||||
setDirty: vi.fn()
|
||||
}))
|
||||
|
||||
const mockColorPaletteStore = vi.hoisted(() => ({
|
||||
customPalettes: {} as Record<string, unknown>,
|
||||
palettesLookup: {} as Record<string, unknown>,
|
||||
completedActivePalette: undefined as unknown,
|
||||
activePaletteId: 'dark',
|
||||
addCustomPalette: vi.fn(),
|
||||
deleteCustomPalette: vi.fn(),
|
||||
completePalette: vi.fn()
|
||||
}))
|
||||
|
||||
const mockSettingStore = vi.hoisted(() => ({
|
||||
get: vi.fn(),
|
||||
set: vi.fn()
|
||||
}))
|
||||
|
||||
const mockNodeDefStore = vi.hoisted(() => ({
|
||||
nodeDataTypes: new Set(['IMAGE', 'MISSING'])
|
||||
}))
|
||||
|
||||
const mockDownloadBlob = vi.hoisted(() => vi.fn())
|
||||
const mockUploadFile = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: { canvas: mockCanvas }
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/workspace/colorPaletteStore', () => ({
|
||||
useColorPaletteStore: () => mockColorPaletteStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => mockSettingStore
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/nodeDefStore', () => ({
|
||||
useNodeDefStore: () => mockNodeDefStore
|
||||
}))
|
||||
|
||||
vi.mock('@/base/common/downloadUtil', () => ({
|
||||
downloadBlob: mockDownloadBlob
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/utils', () => ({
|
||||
uploadFile: mockUploadFile
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useErrorHandling', () => ({
|
||||
useErrorHandling: () => ({
|
||||
wrapWithErrorHandling: <T>(action: T) => action,
|
||||
wrapWithErrorHandlingAsync: <T>(action: T) => action
|
||||
})
|
||||
}))
|
||||
|
||||
import { useColorPaletteService } from './colorPaletteService'
|
||||
|
||||
const validCustomPalette = {
|
||||
id: 'custom',
|
||||
name: 'Custom',
|
||||
colors: {
|
||||
node_slot: {},
|
||||
litegraph_base: {},
|
||||
comfy_base: {}
|
||||
}
|
||||
} satisfies Palette
|
||||
|
||||
function makeCompletedPalette(id = 'custom'): CompletedPalette {
|
||||
const palette = structuredClone(
|
||||
DEFAULT_DARK_COLOR_PALETTE
|
||||
) as CompletedPalette
|
||||
palette.id = id
|
||||
palette.name = 'Custom'
|
||||
palette.colors.node_slot.IMAGE = '#123456'
|
||||
palette.colors.litegraph_base.NODE_TITLE_COLOR = '#abcdef'
|
||||
palette.colors.litegraph_base.LINK_COLOR = '#fedcba'
|
||||
palette.colors.litegraph_base.BACKGROUND_IMAGE = 'grid.png'
|
||||
palette.colors.litegraph_base.CLEAR_BACKGROUND_COLOR = '#010203'
|
||||
palette.colors.litegraph_base.NODE_DEFAULT_SHAPE = 'legacy'
|
||||
palette.colors.comfy_base['fg-color'] = '#111111'
|
||||
palette.colors.comfy_base['bg-color'] = '#222222'
|
||||
delete palette.colors.comfy_base['contrast-mix-color']
|
||||
return palette
|
||||
}
|
||||
|
||||
describe('useColorPaletteService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCanvas.default_connection_color_byType = {}
|
||||
mockCanvas.node_title_color = ''
|
||||
mockCanvas.default_link_color = ''
|
||||
mockCanvas.background_image = ''
|
||||
mockCanvas.clear_background_color = ''
|
||||
mockCanvas._pattern = 'pattern'
|
||||
LGraphCanvas.link_type_colors = {}
|
||||
mockSettingStore.get.mockReturnValue('')
|
||||
mockSettingStore.set.mockResolvedValue(undefined)
|
||||
mockColorPaletteStore.customPalettes = { custom: validCustomPalette }
|
||||
mockColorPaletteStore.palettesLookup = { custom: validCustomPalette }
|
||||
mockColorPaletteStore.completedActivePalette = makeCompletedPalette()
|
||||
mockColorPaletteStore.activePaletteId = 'dark'
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
makeCompletedPalette()
|
||||
)
|
||||
document.documentElement.style.cssText = ''
|
||||
document.documentElement.style.setProperty(
|
||||
'--color-datatype-MISSING',
|
||||
'#ffffff'
|
||||
)
|
||||
})
|
||||
|
||||
it('adds valid custom palettes and persists the custom palette map', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.addCustomColorPalette(validCustomPalette)
|
||||
|
||||
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
|
||||
validCustomPalette
|
||||
)
|
||||
expect(mockSettingStore.set).toHaveBeenCalledWith(
|
||||
'Comfy.CustomColorPalettes',
|
||||
mockColorPaletteStore.customPalettes
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects invalid custom palettes before mutating the store', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await expect(service.addCustomColorPalette({} as Palette)).rejects.toThrow(
|
||||
'Invalid color palette against zod schema'
|
||||
)
|
||||
expect(mockColorPaletteStore.addCustomPalette).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('deletes custom palettes and persists the custom palette map', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.deleteCustomColorPalette('custom')
|
||||
|
||||
expect(mockColorPaletteStore.deleteCustomPalette).toHaveBeenCalledWith(
|
||||
'custom'
|
||||
)
|
||||
expect(mockSettingStore.set).toHaveBeenCalledWith(
|
||||
'Comfy.CustomColorPalettes',
|
||||
mockColorPaletteStore.customPalettes
|
||||
)
|
||||
})
|
||||
|
||||
it('loads palette colors into litegraph, Vue CSS variables, and canvas state', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.default_connection_color_byType.IMAGE).toBe('#123456')
|
||||
expect(LGraphCanvas.link_type_colors.IMAGE).toBe('#123456')
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--color-datatype-IMAGE')
|
||||
).toBe('#123456')
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue(
|
||||
'--color-datatype-MISSING'
|
||||
)
|
||||
).toBe('')
|
||||
expect(mockCanvas.node_title_color).toBe('#abcdef')
|
||||
expect(mockCanvas.default_link_color).toBe('#fedcba')
|
||||
expect(mockCanvas.background_image).toBe('grid.png')
|
||||
expect(mockCanvas.clear_background_color).toBe('#010203')
|
||||
expect(mockCanvas._pattern).toBeUndefined()
|
||||
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.ROUND)
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
`litegraph_base.NODE_DEFAULT_SHAPE only accepts [${[
|
||||
RenderShape.BOX,
|
||||
RenderShape.ROUND,
|
||||
RenderShape.CARD
|
||||
].join(', ')}] but got legacy`
|
||||
)
|
||||
expect(document.documentElement.style.getPropertyValue('--fg-color')).toBe(
|
||||
'#111111'
|
||||
)
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--contrast-mix-color')
|
||||
).toBe('var(--palette-contrast-mix-color)')
|
||||
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
|
||||
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
|
||||
})
|
||||
|
||||
it('skips absent palette sections while still activating the palette', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
fromAny<CompletedPalette, unknown>({
|
||||
...completedPalette,
|
||||
colors: {
|
||||
node_slot: undefined,
|
||||
litegraph_base: completedPalette.colors.litegraph_base,
|
||||
comfy_base: undefined
|
||||
}
|
||||
})
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.node_title_color).toBe('#abcdef')
|
||||
expect(mockCanvas.setDirty).toHaveBeenCalledWith(true, true)
|
||||
expect(mockColorPaletteStore.activePaletteId).toBe('custom')
|
||||
})
|
||||
|
||||
it('removes Vue node theme overrides for built-in palettes', async () => {
|
||||
mockColorPaletteStore.palettesLookup = { dark: validCustomPalette }
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(
|
||||
makeCompletedPalette('dark')
|
||||
)
|
||||
document.documentElement.style.setProperty(
|
||||
'--component-node-border',
|
||||
'#ffffff'
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('dark')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--component-node-border')
|
||||
).toBe('')
|
||||
})
|
||||
|
||||
it('removes Vue node theme variables when completed palette values are absent', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
// NODE_BOX_OUTLINE_COLOR is required on the completed palette type; the
|
||||
// test needs it absent, so delete via Reflect to keep the type intact.
|
||||
Reflect.deleteProperty(
|
||||
completedPalette.colors.litegraph_base,
|
||||
'NODE_BOX_OUTLINE_COLOR'
|
||||
)
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
document.documentElement.style.setProperty(
|
||||
'--component-node-border',
|
||||
'#ffffff'
|
||||
)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--component-node-border')
|
||||
).toBe('')
|
||||
})
|
||||
|
||||
it('preserves numeric LiteGraph node shapes without warning', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.colors.litegraph_base.NODE_DEFAULT_SHAPE = RenderShape.CARD
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(LiteGraph.NODE_DEFAULT_SHAPE).toBe(RenderShape.CARD)
|
||||
expect(warn).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses explicit optional comfy color values when present', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.colors.comfy_base['contrast-mix-color'] = '#333333'
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(
|
||||
document.documentElement.style.getPropertyValue('--contrast-mix-color')
|
||||
).toBe('#333333')
|
||||
})
|
||||
|
||||
it('uses a white splash background for light themes', async () => {
|
||||
const completedPalette = makeCompletedPalette()
|
||||
completedPalette.light_theme = true
|
||||
mockColorPaletteStore.completePalette.mockReturnValue(completedPalette)
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(localStorage.getItem('comfy-splash-bg')).toBe('#FFFFFF')
|
||||
expect(localStorage.getItem('comfy-splash-fg')).toBe('#111111')
|
||||
})
|
||||
|
||||
it('uses transparent canvas background and bg image CSS when a background image setting exists', async () => {
|
||||
mockSettingStore.get.mockReturnValue('/custom/background.png')
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await service.loadColorPalette('custom')
|
||||
|
||||
expect(mockCanvas.clear_background_color).toBe('transparent')
|
||||
expect(document.documentElement.style.getPropertyValue('--bg-img')).toBe(
|
||||
"url('/custom/background.png')"
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when loading or exporting an unknown palette', async () => {
|
||||
mockColorPaletteStore.palettesLookup = {}
|
||||
const service = useColorPaletteService()
|
||||
|
||||
await expect(service.loadColorPalette('missing')).rejects.toThrow(
|
||||
'Color palette missing not found'
|
||||
)
|
||||
expect(() => service.exportColorPalette('missing')).toThrow(
|
||||
'Color palette missing not found'
|
||||
)
|
||||
})
|
||||
|
||||
it('exports palette JSON by id', async () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
service.exportColorPalette('custom')
|
||||
|
||||
expect(mockDownloadBlob).toHaveBeenCalledOnce()
|
||||
const [filename, blob] = mockDownloadBlob.mock.calls[0] as [string, Blob]
|
||||
expect(filename).toBe('custom.json')
|
||||
await expect(blob.text()).resolves.toContain('"id": "custom"')
|
||||
})
|
||||
|
||||
it('imports palette JSON through the custom palette path', async () => {
|
||||
mockUploadFile.mockResolvedValue({
|
||||
text: () => Promise.resolve(JSON.stringify(validCustomPalette))
|
||||
})
|
||||
const service = useColorPaletteService()
|
||||
|
||||
const palette = await service.importColorPalette()
|
||||
|
||||
expect(mockUploadFile).toHaveBeenCalledWith('application/json')
|
||||
expect(palette).toEqual(validCustomPalette)
|
||||
expect(mockColorPaletteStore.addCustomPalette).toHaveBeenCalledWith(
|
||||
validCustomPalette
|
||||
)
|
||||
})
|
||||
|
||||
it('returns the completed active palette from the store', () => {
|
||||
const service = useColorPaletteService()
|
||||
|
||||
expect(service.getActiveColorPalette()).toBe(
|
||||
mockColorPaletteStore.completedActivePalette
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -18,7 +18,7 @@ import type {
|
||||
} from '@/stores/dialogStore'
|
||||
|
||||
import type { ComponentAttrs } from 'vue-component-type-helpers'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { WorkspaceRole } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
// Lazy loaders for dialogs - components are loaded on first use
|
||||
@@ -442,9 +442,9 @@ export const useDialogService = () => {
|
||||
})
|
||||
}
|
||||
|
||||
async function showSubscriptionRequiredDialog(options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) {
|
||||
async function showSubscriptionRequiredDialog(
|
||||
options?: SubscriptionDialogOptions
|
||||
) {
|
||||
if (!isCloud || !window.__CONFIG__?.subscription_required) {
|
||||
return
|
||||
}
|
||||
|
||||
324
src/services/extensionService.test.ts
Normal file
324
src/services/extensionService.test.ts
Normal file
@@ -0,0 +1,324 @@
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { AuthUserInfo } from '@/types/authTypes'
|
||||
import type { ComfyExtension } from '@/types/comfy'
|
||||
import { useExtensionService } from './extensionService'
|
||||
|
||||
const mockLoadDisabledExtensionNames = vi.hoisted(() => vi.fn())
|
||||
const mockRegisterExtension = vi.hoisted(() => vi.fn())
|
||||
const mockCaptureCoreExtensions = vi.hoisted(() => vi.fn())
|
||||
const mockEnabledExtensions = vi.hoisted(() => ({
|
||||
value: [] as ComfyExtension[]
|
||||
}))
|
||||
vi.mock('@/stores/extensionStore', () => ({
|
||||
useExtensionStore: () => ({
|
||||
loadDisabledExtensionNames: mockLoadDisabledExtensionNames,
|
||||
registerExtension: mockRegisterExtension,
|
||||
captureCoreExtensions: mockCaptureCoreExtensions,
|
||||
get enabledExtensions() {
|
||||
return mockEnabledExtensions.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
const mockGetSetting = vi.hoisted(() => vi.fn())
|
||||
const mockAddSetting = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
get: mockGetSetting,
|
||||
addSetting: mockAddSetting
|
||||
})
|
||||
}))
|
||||
|
||||
const mockAddDefaultKeybinding = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/keybindings/keybindingStore', () => ({
|
||||
useKeybindingStore: () => ({
|
||||
addDefaultKeybinding: mockAddDefaultKeybinding
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/keybindings/keybinding', () => ({
|
||||
KeybindingImpl: class KeybindingImpl {
|
||||
constructor(readonly source: unknown) {}
|
||||
}
|
||||
}))
|
||||
|
||||
const mockLoadExtensionCommands = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({
|
||||
loadExtensionCommands: mockLoadExtensionCommands
|
||||
})
|
||||
}))
|
||||
|
||||
const mockLoadExtensionMenuCommands = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/menuItemStore', () => ({
|
||||
useMenuItemStore: () => ({
|
||||
loadExtensionMenuCommands: mockLoadExtensionMenuCommands
|
||||
})
|
||||
}))
|
||||
|
||||
const mockRegisterBottomPanelTabs = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/workspace/bottomPanelStore', () => ({
|
||||
useBottomPanelStore: () => ({
|
||||
registerExtensionBottomPanelTabs: mockRegisterBottomPanelTabs
|
||||
})
|
||||
}))
|
||||
|
||||
const mockRegisterCustomWidgets = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/widgetStore', () => ({
|
||||
useWidgetStore: () => ({
|
||||
registerCustomWidgets: mockRegisterCustomWidgets
|
||||
})
|
||||
}))
|
||||
|
||||
const mockToastErrorHandler = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/composables/useErrorHandling', () => ({
|
||||
useErrorHandling: () => ({
|
||||
wrapWithErrorHandling:
|
||||
<Args extends unknown[], Return>(fn: (...args: Args) => Return) =>
|
||||
(...args: Args) =>
|
||||
fn(...args),
|
||||
wrapWithErrorHandlingAsync:
|
||||
<Args extends unknown[], Return>(
|
||||
fn: (...args: Args) => Return | Promise<Return>,
|
||||
handler: (error: unknown) => void
|
||||
) =>
|
||||
async (...args: Args) => {
|
||||
try {
|
||||
return await fn(...args)
|
||||
} catch (error) {
|
||||
handler(error)
|
||||
}
|
||||
},
|
||||
toastErrorHandler: mockToastErrorHandler
|
||||
})
|
||||
}))
|
||||
|
||||
const mockUserResolvedCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<(user: AuthUserInfo) => void>
|
||||
}))
|
||||
const mockTokenRefreshedCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<() => void>
|
||||
}))
|
||||
const mockUserLogoutCallbacks = vi.hoisted(() => ({
|
||||
values: [] as Array<() => void>
|
||||
}))
|
||||
vi.mock('@/composables/auth/useCurrentUser', () => ({
|
||||
useCurrentUser: () => ({
|
||||
onUserResolved: (callback: (user: AuthUserInfo) => void) => {
|
||||
mockUserResolvedCallbacks.values.push(callback)
|
||||
},
|
||||
onTokenRefreshed: (callback: () => void) => {
|
||||
mockTokenRefreshedCallbacks.values.push(callback)
|
||||
},
|
||||
onUserLogout: (callback: () => void) => {
|
||||
mockUserLogoutCallbacks.values.push(callback)
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
const mockSetCurrentExtension = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/lib/litegraph/src/contextMenuCompat', () => ({
|
||||
legacyMenuCompat: {
|
||||
setCurrentExtension: mockSetCurrentExtension
|
||||
}
|
||||
}))
|
||||
|
||||
const mockApp = vi.hoisted(() => ({ value: { name: 'app' } }))
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: mockApp.value
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
getExtensions: vi.fn(),
|
||||
fileURL: vi.fn((path: string) => path)
|
||||
}
|
||||
}))
|
||||
|
||||
describe('useExtensionService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockEnabledExtensions.value = []
|
||||
mockUserResolvedCallbacks.values = []
|
||||
mockTokenRefreshedCallbacks.values = []
|
||||
mockUserLogoutCallbacks.values = []
|
||||
})
|
||||
|
||||
it('registers extension contributions across stores', async () => {
|
||||
const widgets = { CustomWidget: vi.fn() }
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'registration-extension',
|
||||
keybindings: [{ commandId: 'command.one', combo: { key: 'K' } }],
|
||||
commands: [{ id: 'command.one', label: 'Command One' }],
|
||||
menuCommands: [{ path: ['File'], commands: ['command.one'] }],
|
||||
settings: [{ id: 'setting.one', name: 'Setting One' }],
|
||||
bottomPanelTabs: [{ id: 'tab.one', title: 'Tab One' }],
|
||||
getCustomWidgets: vi.fn().mockResolvedValue(widgets)
|
||||
})
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
|
||||
expect(mockRegisterExtension).toHaveBeenCalledWith(extension)
|
||||
expect(mockAddDefaultKeybinding).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
source: { commandId: 'command.one', combo: { key: 'K' } }
|
||||
})
|
||||
)
|
||||
expect(mockLoadExtensionCommands).toHaveBeenCalledWith(extension)
|
||||
expect(mockLoadExtensionMenuCommands).toHaveBeenCalledWith(extension)
|
||||
expect(mockAddSetting.mock.calls[0][0]).toEqual({
|
||||
id: 'setting.one',
|
||||
name: 'Setting One'
|
||||
})
|
||||
expect(mockRegisterBottomPanelTabs).toHaveBeenCalledWith(extension)
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRegisterCustomWidgets).toHaveBeenCalledWith(widgets)
|
||||
})
|
||||
})
|
||||
|
||||
it('invokes auth lifecycle hooks through registered callbacks', async () => {
|
||||
const onAuthUserResolved = vi.fn()
|
||||
const onAuthTokenRefreshed = vi.fn()
|
||||
const onAuthUserLogout = vi.fn()
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'auth-extension',
|
||||
onAuthUserResolved,
|
||||
onAuthTokenRefreshed,
|
||||
onAuthUserLogout
|
||||
})
|
||||
const user = fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
mockUserResolvedCallbacks.values[0](user)
|
||||
mockTokenRefreshedCallbacks.values[0]()
|
||||
mockUserLogoutCallbacks.values[0]()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAuthUserResolved).toHaveBeenCalledWith(user, mockApp.value)
|
||||
expect(onAuthTokenRefreshed).toHaveBeenCalled()
|
||||
expect(onAuthUserLogout).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('reports auth hook errors through the toast handler', async () => {
|
||||
const error = new Error('auth failed')
|
||||
const extension = fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-auth-extension',
|
||||
onAuthUserResolved: vi.fn(() => {
|
||||
throw error
|
||||
})
|
||||
})
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const service = useExtensionService()
|
||||
|
||||
service.registerExtension(extension)
|
||||
mockUserResolvedCallbacks.values[0](
|
||||
fromAny<AuthUserInfo, unknown>({ id: 'user-1' })
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastErrorHandler).toHaveBeenCalledWith(error)
|
||||
})
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
'[Extension Auth Hook Error]',
|
||||
expect.objectContaining({
|
||||
extension: 'failing-auth-extension',
|
||||
hook: 'onAuthUserResolved',
|
||||
error
|
||||
})
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('invokes synchronous extension methods and keeps failures isolated', () => {
|
||||
const getSelectionToolboxCommands = vi.fn(() => ['command.one'])
|
||||
const failingGetSelectionToolboxCommands = vi.fn(() => {
|
||||
throw new Error('menu failed')
|
||||
})
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockEnabledExtensions.value = [
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'working-extension',
|
||||
getSelectionToolboxCommands
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'non-function-extension',
|
||||
getSelectionToolboxCommands: ['not callable']
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-extension',
|
||||
getSelectionToolboxCommands: failingGetSelectionToolboxCommands
|
||||
}),
|
||||
{ name: 'missing-method-extension' }
|
||||
]
|
||||
const service = useExtensionService()
|
||||
|
||||
const results = service.invokeExtensions(
|
||||
'getSelectionToolboxCommands',
|
||||
fromAny<LGraphNode, unknown>({ id: 1 })
|
||||
)
|
||||
|
||||
expect(results).toEqual([['command.one']])
|
||||
expect(getSelectionToolboxCommands).toHaveBeenCalledWith(
|
||||
fromAny<LGraphNode, unknown>({ id: 1 }),
|
||||
mockApp.value
|
||||
)
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
"Error calling extension 'failing-extension' method 'getSelectionToolboxCommands'",
|
||||
expect.objectContaining({ error: expect.any(Error) }),
|
||||
expect.objectContaining({
|
||||
extension: expect.objectContaining({ name: 'failing-extension' })
|
||||
}),
|
||||
expect.objectContaining({
|
||||
args: [fromAny<LGraphNode, unknown>({ id: 1 })]
|
||||
})
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('tracks current extension around async setup callbacks', async () => {
|
||||
const setup = vi.fn().mockResolvedValue('setup-result')
|
||||
const failingSetup = vi.fn().mockRejectedValue(new Error('setup failed'))
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
mockEnabledExtensions.value = [
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'setup-extension',
|
||||
setup
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'non-function-extension',
|
||||
setup: true
|
||||
}),
|
||||
fromAny<ComfyExtension, unknown>({
|
||||
name: 'failing-setup-extension',
|
||||
setup: failingSetup
|
||||
}),
|
||||
{ name: 'missing-method-extension' }
|
||||
]
|
||||
const service = useExtensionService()
|
||||
|
||||
const results = await service.invokeExtensionsAsync('setup')
|
||||
|
||||
expect(results).toEqual(['setup-result', undefined, undefined, undefined])
|
||||
expect(mockSetCurrentExtension.mock.calls.map((call) => call[0])).toEqual([
|
||||
'setup-extension',
|
||||
'failing-setup-extension',
|
||||
null,
|
||||
null
|
||||
])
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
"Error calling extension 'failing-setup-extension' method 'setup'",
|
||||
expect.objectContaining({ error: expect.any(Error) }),
|
||||
expect.objectContaining({
|
||||
extension: expect.objectContaining({ name: 'failing-setup-extension' })
|
||||
}),
|
||||
expect.objectContaining({ args: [] })
|
||||
)
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,35 +1,181 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useMediaCache } from './mediaCacheService'
|
||||
import type { useMediaCache } from './mediaCacheService'
|
||||
|
||||
// Mock fetch
|
||||
global.fetch = vi.fn()
|
||||
global.URL = {
|
||||
createObjectURL: vi.fn(() => 'blob:mock-url'),
|
||||
revokeObjectURL: vi.fn()
|
||||
} as Partial<typeof URL> as typeof URL
|
||||
type MediaCache = ReturnType<typeof useMediaCache>
|
||||
|
||||
describe('mediaCacheService', () => {
|
||||
describe('URL reference counting', () => {
|
||||
it('should handle URL acquisition for non-existent cache entry', () => {
|
||||
const { acquireUrl } = useMediaCache()
|
||||
const mockFetch = vi.fn()
|
||||
const mockCreateObjectURL = vi.fn()
|
||||
const mockRevokeObjectURL = vi.fn()
|
||||
const NativeURL = URL
|
||||
|
||||
const url = acquireUrl('non-existent.jpg')
|
||||
expect(url).toBeUndefined()
|
||||
class MockURL extends NativeURL {
|
||||
static override createObjectURL(blob: Blob): string {
|
||||
return mockCreateObjectURL(blob)
|
||||
}
|
||||
|
||||
static override revokeObjectURL(url: string): void {
|
||||
mockRevokeObjectURL(url)
|
||||
}
|
||||
}
|
||||
|
||||
function response(ok: boolean, blob = new Blob(['image'])): Response {
|
||||
return {
|
||||
ok,
|
||||
status: ok ? 200 : 404,
|
||||
blob: () => Promise.resolve(blob)
|
||||
} as Response
|
||||
}
|
||||
|
||||
async function freshCache(options?: {
|
||||
maxSize?: number
|
||||
maxAge?: number
|
||||
}): Promise<MediaCache> {
|
||||
vi.resetModules()
|
||||
const { useMediaCache } = await import('./mediaCacheService')
|
||||
return useMediaCache(options)
|
||||
}
|
||||
|
||||
describe('useMediaCache', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.setSystemTime(0)
|
||||
mockFetch.mockReset()
|
||||
mockCreateObjectURL.mockReset()
|
||||
mockRevokeObjectURL.mockReset()
|
||||
mockCreateObjectURL.mockImplementation(
|
||||
(_blob: Blob) => `blob:${mockCreateObjectURL.mock.calls.length}`
|
||||
)
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
vi.stubGlobal('URL', MockURL)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
window.dispatchEvent(new Event('beforeunload'))
|
||||
vi.useRealTimers()
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('fetches media once and returns cached entries on later reads', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache()
|
||||
|
||||
const first = await cache.getCachedMedia('/image.png')
|
||||
vi.setSystemTime(100)
|
||||
const second = await cache.getCachedMedia('/image.png')
|
||||
|
||||
expect(first).toMatchObject({
|
||||
src: '/image.png',
|
||||
objectUrl: 'blob:1',
|
||||
isLoading: false
|
||||
})
|
||||
|
||||
it('should handle URL release for non-existent cache entry', () => {
|
||||
const { releaseUrl } = useMediaCache()
|
||||
|
||||
// Should not throw error
|
||||
expect(() => releaseUrl('non-existent.jpg')).not.toThrow()
|
||||
})
|
||||
|
||||
it('should provide acquireUrl and releaseUrl methods', () => {
|
||||
const cache = useMediaCache()
|
||||
|
||||
expect(typeof cache.acquireUrl).toBe('function')
|
||||
expect(typeof cache.releaseUrl).toBe('function')
|
||||
expect(second).toEqual(first)
|
||||
expect(second.lastAccessed).toBe(100)
|
||||
expect(mockFetch).toHaveBeenCalledOnce()
|
||||
expect(mockFetch).toHaveBeenCalledWith('/image.png', {
|
||||
cache: 'force-cache'
|
||||
})
|
||||
})
|
||||
|
||||
it('stores an error entry when fetch fails', async () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
mockFetch.mockResolvedValue(response(false))
|
||||
const cache = await freshCache()
|
||||
|
||||
const entry = await cache.getCachedMedia('/missing.png')
|
||||
|
||||
expect(entry).toMatchObject({
|
||||
src: '/missing.png',
|
||||
error: true,
|
||||
isLoading: false
|
||||
})
|
||||
expect(warn).toHaveBeenCalledWith(
|
||||
'Failed to cache media:',
|
||||
'/missing.png',
|
||||
expect.any(Error)
|
||||
)
|
||||
})
|
||||
|
||||
it('ref-counts acquired object URLs and removes the cache entry on final release', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache()
|
||||
await cache.getCachedMedia('/image.png')
|
||||
|
||||
expect(cache.acquireUrl('/image.png')).toBe('blob:1')
|
||||
expect(cache.acquireUrl('/image.png')).toBe('blob:1')
|
||||
cache.releaseUrl('/image.png')
|
||||
expect(mockRevokeObjectURL).not.toHaveBeenCalled()
|
||||
expect(cache.cache.has('/image.png')).toBe(true)
|
||||
|
||||
cache.releaseUrl('/image.png')
|
||||
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
expect(cache.cache.has('/image.png')).toBe(false)
|
||||
})
|
||||
|
||||
it('returns undefined when acquiring a URL that is not cached', async () => {
|
||||
const cache = await freshCache()
|
||||
|
||||
expect(cache.acquireUrl('/missing.png')).toBeUndefined()
|
||||
cache.releaseUrl('/missing.png')
|
||||
expect(mockRevokeObjectURL).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('expires old cache entries during scheduled cleanup', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache({ maxAge: 100 })
|
||||
await cache.getCachedMedia('/old.png')
|
||||
|
||||
vi.setSystemTime(200)
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(cache.cache.has('/old.png')).toBe(false)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
})
|
||||
|
||||
it('keeps expired entries while their object URL is still acquired', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache({ maxAge: 100 })
|
||||
await cache.getCachedMedia('/held.png')
|
||||
cache.acquireUrl('/held.png')
|
||||
|
||||
vi.setSystemTime(200)
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(cache.cache.has('/held.png')).toBe(true)
|
||||
expect(mockRevokeObjectURL).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('removes the oldest unused entries when the cache is over size', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache({ maxSize: 1, maxAge: 1_000_000 })
|
||||
await cache.getCachedMedia('/old.png')
|
||||
vi.setSystemTime(1)
|
||||
await cache.getCachedMedia('/new.png')
|
||||
|
||||
vi.advanceTimersByTime(5 * 60 * 1000)
|
||||
|
||||
expect(cache.cache.has('/old.png')).toBe(false)
|
||||
expect(cache.cache.has('/new.png')).toBe(true)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
})
|
||||
|
||||
it('clears all cached URLs on demand and before unload', async () => {
|
||||
mockFetch.mockResolvedValue(response(true))
|
||||
const cache = await freshCache()
|
||||
await cache.getCachedMedia('/first.png')
|
||||
await cache.getCachedMedia('/second.png')
|
||||
|
||||
cache.clearCache()
|
||||
|
||||
expect(cache.cache.size).toBe(0)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:1')
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:2')
|
||||
|
||||
await cache.getCachedMedia('/third.png')
|
||||
window.dispatchEvent(new Event('beforeunload'))
|
||||
|
||||
expect(cache.cache.size).toBe(0)
|
||||
expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:3')
|
||||
})
|
||||
})
|
||||
|
||||
162
src/services/nodeHelpService.test.ts
Normal file
162
src/services/nodeHelpService.test.ts
Normal file
@@ -0,0 +1,162 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { api } from '@/scripts/api'
|
||||
import { nodeHelpService } from '@/services/nodeHelpService'
|
||||
import type { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
fileURL: vi.fn((path: string) => `/files${path}`)
|
||||
}
|
||||
}))
|
||||
|
||||
describe('nodeHelpService', () => {
|
||||
const mockFetch = vi.fn<typeof fetch>()
|
||||
|
||||
function nodeDef(options: {
|
||||
name?: string
|
||||
python_module?: string
|
||||
description?: string
|
||||
}): ComfyNodeDefImpl {
|
||||
return {
|
||||
name: options.name ?? 'TestNode',
|
||||
display_name: options.name ?? 'Test Node',
|
||||
category: 'test',
|
||||
python_module: options.python_module ?? 'nodes',
|
||||
description: options.description ?? ''
|
||||
} as ComfyNodeDefImpl
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockFetch.mockReset()
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
})
|
||||
|
||||
it('returns blueprint descriptions without fetching markdown', async () => {
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
python_module: 'blueprint',
|
||||
description: 'Saved workflow help'
|
||||
}),
|
||||
'en'
|
||||
)
|
||||
|
||||
expect(help).toBe('Saved workflow help')
|
||||
expect(mockFetch).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('fetches localized core node markdown', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
new Response('Core help', {
|
||||
headers: { 'content-type': 'text/markdown' }
|
||||
})
|
||||
)
|
||||
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({ name: 'LoadImage' }),
|
||||
'zh'
|
||||
)
|
||||
|
||||
expect(api.fileURL).toHaveBeenCalledWith('/docs/LoadImage/zh.md')
|
||||
expect(mockFetch).toHaveBeenCalledWith('/files/docs/LoadImage/zh.md')
|
||||
expect(help).toBe('Core help')
|
||||
})
|
||||
|
||||
it('rejects core node HTML fallbacks', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
new Response('<html></html>', {
|
||||
headers: { 'content-type': 'text/html' },
|
||||
statusText: 'OK'
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
nodeHelpService.fetchNodeHelp(nodeDef({ name: 'PreviewImage' }), 'en')
|
||||
).rejects.toThrow('OK')
|
||||
})
|
||||
|
||||
it('uses the default missing-help error for empty markdown responses', async () => {
|
||||
mockFetch.mockResolvedValueOnce(new Response(''))
|
||||
|
||||
await expect(
|
||||
nodeHelpService.fetchNodeHelp(nodeDef({ name: 'EmptyHelp' }), 'en')
|
||||
).rejects.toThrow('Help not found')
|
||||
})
|
||||
|
||||
it('fetches custom node localized markdown before fallback markdown', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
new Response('Custom localized help', {
|
||||
headers: { 'content-type': 'text/markdown' }
|
||||
})
|
||||
)
|
||||
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
name: 'CustomNode',
|
||||
python_module: 'custom_nodes.ComfyUI-TestPack@1.2.3.nodes'
|
||||
}),
|
||||
'ja'
|
||||
)
|
||||
|
||||
expect(api.fileURL).toHaveBeenCalledWith(
|
||||
'/extensions/ComfyUI-TestPack/docs/CustomNode/ja.md'
|
||||
)
|
||||
expect(mockFetch).toHaveBeenCalledTimes(1)
|
||||
expect(help).toBe('Custom localized help')
|
||||
})
|
||||
|
||||
it('falls back to custom node default markdown after locale miss', async () => {
|
||||
mockFetch
|
||||
.mockResolvedValueOnce(
|
||||
new Response('Not found', {
|
||||
status: 404,
|
||||
statusText: 'Locale missing'
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(
|
||||
new Response('Custom fallback help', {
|
||||
headers: { 'content-type': 'text/markdown' }
|
||||
})
|
||||
)
|
||||
|
||||
const help = await nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
name: 'CustomNode',
|
||||
python_module: 'custom_nodes.TestPack.nodes'
|
||||
}),
|
||||
'fr'
|
||||
)
|
||||
|
||||
expect(api.fileURL).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'/extensions/TestPack/docs/CustomNode/fr.md'
|
||||
)
|
||||
expect(api.fileURL).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'/extensions/TestPack/docs/CustomNode.md'
|
||||
)
|
||||
expect(help).toBe('Custom fallback help')
|
||||
})
|
||||
|
||||
it('reports the locale miss when the custom fallback is empty', async () => {
|
||||
mockFetch
|
||||
.mockResolvedValueOnce(
|
||||
new Response('Not found', {
|
||||
status: 404,
|
||||
statusText: 'Locale missing'
|
||||
})
|
||||
)
|
||||
.mockResolvedValueOnce(new Response(''))
|
||||
|
||||
await expect(
|
||||
nodeHelpService.fetchNodeHelp(
|
||||
nodeDef({
|
||||
name: 'CustomNode',
|
||||
python_module: 'custom_nodes.TestPack.nodes'
|
||||
}),
|
||||
'de'
|
||||
)
|
||||
).rejects.toThrow('Locale missing')
|
||||
})
|
||||
})
|
||||
@@ -124,6 +124,80 @@ describe('nodeOrganizationService', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('organizeNodesTab', () => {
|
||||
it('returns no sections for an empty node list', () => {
|
||||
expect(nodeOrganizationService.organizeNodesTab([])).toEqual([])
|
||||
})
|
||||
|
||||
it('classifies blueprints, partner nodes, Comfy nodes, and extensions', () => {
|
||||
const sections = nodeOrganizationService.organizeNodesTab([
|
||||
createMockNodeDef({
|
||||
name: 'MyBlueprint',
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Blueprint,
|
||||
className: 'blueprint',
|
||||
displayText: 'Blueprint',
|
||||
badgeText: 'B'
|
||||
},
|
||||
isGlobal: false
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'ComfyBlueprint',
|
||||
python_module: 'blueprint.comfy',
|
||||
isGlobal: true
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'PartnerApi',
|
||||
category: 'api node/image'
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'CoreNode',
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Core,
|
||||
className: 'core',
|
||||
displayText: 'Core',
|
||||
badgeText: 'C'
|
||||
}
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'EssentialNode',
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Essentials,
|
||||
className: 'essentials',
|
||||
displayText: 'Essentials',
|
||||
badgeText: 'E'
|
||||
}
|
||||
}),
|
||||
createMockNodeDef({
|
||||
name: 'ExtensionNode'
|
||||
})
|
||||
])
|
||||
|
||||
expect(sections.map((section) => section.category)).toEqual([
|
||||
'blueprints',
|
||||
'comfyNodes',
|
||||
'partnerNodes',
|
||||
'extensions'
|
||||
])
|
||||
expect(sections[0].tree.children?.map((child) => child.key)).toEqual([
|
||||
'root/my-blueprints',
|
||||
'root/comfy-blueprints'
|
||||
])
|
||||
})
|
||||
|
||||
it('omits sections that have no matching nodes', () => {
|
||||
const sections = nodeOrganizationService.organizeNodesTab([
|
||||
createMockNodeDef({
|
||||
name: 'OnlyExtension'
|
||||
})
|
||||
])
|
||||
|
||||
expect(sections.map((section) => section.category)).toEqual([
|
||||
'extensions'
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getGroupingIcon', () => {
|
||||
it('should return strategy icon', () => {
|
||||
const icon = nodeOrganizationService.getGroupingIcon('category')
|
||||
@@ -216,6 +290,12 @@ describe('nodeOrganizationService', () => {
|
||||
expect(path).toEqual(['core', 'TestNode'])
|
||||
})
|
||||
|
||||
it('should handle custom_nodes without a package segment', () => {
|
||||
const nodeDef = createMockNodeDef({ python_module: 'custom_nodes' })
|
||||
const path = strategy?.getNodePath(nodeDef)
|
||||
expect(path).toEqual(['custom_nodes', 'TestNode'])
|
||||
})
|
||||
|
||||
it('should handle non-standard module paths', () => {
|
||||
const nodeDef = createMockNodeDef({
|
||||
python_module: 'some.other.module.path'
|
||||
@@ -282,6 +362,19 @@ describe('nodeOrganizationService', () => {
|
||||
const path = strategy?.getNodePath(nodeDef)
|
||||
expect(path).toEqual(['Unknown', 'TestNode'])
|
||||
})
|
||||
|
||||
it('should handle core source type', () => {
|
||||
const nodeDef = createMockNodeDef({
|
||||
nodeSource: {
|
||||
type: NodeSourceType.Core,
|
||||
className: 'core',
|
||||
displayText: 'Core',
|
||||
badgeText: 'C'
|
||||
}
|
||||
})
|
||||
const path = strategy?.getNodePath(nodeDef)
|
||||
expect(path).toEqual(['Core', 'TestNode'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('node name edge cases', () => {
|
||||
@@ -326,5 +419,14 @@ describe('nodeOrganizationService', () => {
|
||||
expect(strategy?.compare(nodeA, nodeB)).toBeGreaterThan(0)
|
||||
expect(strategy?.compare(nodeB, nodeA)).toBeLessThan(0)
|
||||
})
|
||||
|
||||
it('alphabetical sort handles missing display names', () => {
|
||||
const strategy =
|
||||
nodeOrganizationService.getSortingStrategy('alphabetical')
|
||||
const nodeA = createMockNodeDef({ display_name: undefined })
|
||||
const nodeB = createMockNodeDef({ display_name: undefined })
|
||||
|
||||
expect(strategy?.compare(nodeA, nodeB)).toBe(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
172
src/services/subgraphService.test.ts
Normal file
172
src/services/subgraphService.test.ts
Normal file
@@ -0,0 +1,172 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type {
|
||||
ExportedSubgraph,
|
||||
ExportedSubgraphInstance,
|
||||
Subgraph
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
addNodeDef: vi.fn(),
|
||||
createSubgraph: vi.fn((subgraph: unknown) => ({
|
||||
createdFrom: subgraph
|
||||
})),
|
||||
registerSubgraphNodeDef: vi.fn(),
|
||||
subgraphs: new Map<string, unknown>()
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/nodeDefStore', () => ({
|
||||
useNodeDefStore: () => ({
|
||||
addNodeDef: mocks.addNodeDef
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
rootGraph: {
|
||||
subgraphs: mocks.subgraphs,
|
||||
createSubgraph: mocks.createSubgraph
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./litegraphService', () => ({
|
||||
useLitegraphService: () => ({
|
||||
registerSubgraphNodeDef: mocks.registerSubgraphNodeDef
|
||||
})
|
||||
}))
|
||||
|
||||
const { useSubgraphService } = await import('./subgraphService')
|
||||
|
||||
function createExportedSubgraph(
|
||||
overrides: Partial<ExportedSubgraph> = {}
|
||||
): ExportedSubgraph {
|
||||
return {
|
||||
id: 'subgraph-1',
|
||||
name: 'Test Subgraph',
|
||||
...overrides
|
||||
} as ExportedSubgraph
|
||||
}
|
||||
|
||||
function createWorkflow(subgraphs?: ExportedSubgraph[]): ComfyWorkflowJSON {
|
||||
return {
|
||||
definitions: subgraphs
|
||||
? {
|
||||
subgraphs
|
||||
}
|
||||
: undefined
|
||||
} as unknown as ComfyWorkflowJSON
|
||||
}
|
||||
|
||||
describe('useSubgraphService', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mocks.subgraphs.clear()
|
||||
})
|
||||
|
||||
it('registers a new subgraph node definition', () => {
|
||||
const service = useSubgraphService()
|
||||
const subgraph = { id: 'runtime-subgraph' } as unknown as Subgraph
|
||||
const exportedSubgraph = createExportedSubgraph()
|
||||
|
||||
service.registerNewSubgraph(subgraph, exportedSubgraph)
|
||||
|
||||
expect(mocks.addNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
name: 'subgraph-1',
|
||||
display_name: 'Test Subgraph',
|
||||
description: 'Subgraph node for Test Subgraph',
|
||||
category: 'subgraph',
|
||||
output_node: false,
|
||||
python_module: 'nodes'
|
||||
})
|
||||
)
|
||||
|
||||
const [nodeDef, registeredSubgraph, instanceData] = mocks
|
||||
.registerSubgraphNodeDef.mock.calls[0] as [
|
||||
ComfyNodeDefV1,
|
||||
Subgraph,
|
||||
ExportedSubgraphInstance
|
||||
]
|
||||
|
||||
expect(nodeDef.name).toBe('subgraph-1')
|
||||
expect(registeredSubgraph).toBe(subgraph)
|
||||
expect(instanceData).toMatchObject({
|
||||
id: -1,
|
||||
type: 'subgraph-1',
|
||||
pos: [0, 0],
|
||||
size: [100, 100],
|
||||
inputs: [],
|
||||
outputs: []
|
||||
})
|
||||
})
|
||||
|
||||
it('uses an exported description when present', () => {
|
||||
const service = useSubgraphService()
|
||||
const subgraph = { id: 'runtime-subgraph' } as unknown as Subgraph
|
||||
|
||||
service.registerNewSubgraph(
|
||||
subgraph,
|
||||
createExportedSubgraph({
|
||||
description: 'Reusable workflow section'
|
||||
})
|
||||
)
|
||||
|
||||
expect(mocks.addNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
description: 'Reusable workflow section'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when workflow data has no subgraph definitions', () => {
|
||||
const service = useSubgraphService()
|
||||
|
||||
service.loadSubgraphs(createWorkflow())
|
||||
|
||||
expect(mocks.addNodeDef).not.toHaveBeenCalled()
|
||||
expect(mocks.registerSubgraphNodeDef).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('registers existing root graph subgraphs from workflow data', () => {
|
||||
const service = useSubgraphService()
|
||||
const subgraph = { id: 'existing-subgraph' } as unknown as Subgraph
|
||||
const exportedSubgraph = createExportedSubgraph()
|
||||
mocks.subgraphs.set('subgraph-1', subgraph)
|
||||
|
||||
service.loadSubgraphs(createWorkflow([exportedSubgraph]))
|
||||
|
||||
expect(mocks.createSubgraph).not.toHaveBeenCalled()
|
||||
expect(mocks.registerSubgraphNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
name: 'subgraph-1'
|
||||
}),
|
||||
subgraph,
|
||||
expect.objectContaining({
|
||||
type: 'subgraph-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('creates missing root graph subgraphs from workflow data', () => {
|
||||
const service = useSubgraphService()
|
||||
const exportedSubgraph = createExportedSubgraph()
|
||||
|
||||
service.loadSubgraphs(createWorkflow([exportedSubgraph]))
|
||||
|
||||
expect(mocks.createSubgraph).toHaveBeenCalledWith(exportedSubgraph)
|
||||
expect(mocks.registerSubgraphNodeDef).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
name: 'subgraph-1'
|
||||
}),
|
||||
expect.objectContaining({
|
||||
createdFrom: exportedSubgraph
|
||||
}),
|
||||
expect.objectContaining({
|
||||
type: 'subgraph-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
300
src/stores/assetExportStore.test.ts
Normal file
300
src/stores/assetExportStore.test.ts
Normal file
@@ -0,0 +1,300 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type * as VueUse from '@vueuse/core'
|
||||
|
||||
import type { AssetExportWsMessage } from '@/schemas/apiSchema'
|
||||
import { api } from '@/scripts/api'
|
||||
import type { TaskId } from '@/platform/tasks/services/taskService'
|
||||
import { useAssetExportStore } from '@/stores/assetExportStore'
|
||||
|
||||
const { getExportDownloadUrl, getTask, toastAdd, intervalState } = vi.hoisted(
|
||||
() => ({
|
||||
getExportDownloadUrl: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
toastAdd: vi.fn(),
|
||||
intervalState: { cb: null as null | (() => void) }
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@vueuse/core', async (importOriginal) => ({
|
||||
...(await importOriginal<typeof VueUse>()),
|
||||
useIntervalFn: (cb: () => void) => {
|
||||
intervalState.cb = cb
|
||||
return { pause: vi.fn(), resume: vi.fn() }
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: { addEventListener: vi.fn() }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: { getExportDownloadUrl }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/tasks/services/taskService', () => ({
|
||||
taskService: { getTask }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: toastAdd })
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string) => key
|
||||
}))
|
||||
|
||||
function wsMessage(
|
||||
over: Partial<AssetExportWsMessage> = {}
|
||||
): AssetExportWsMessage {
|
||||
return {
|
||||
task_id: 'task-1',
|
||||
export_name: 'export.zip',
|
||||
assets_total: 10,
|
||||
assets_attempted: 5,
|
||||
assets_failed: 0,
|
||||
bytes_total: 1000,
|
||||
bytes_processed: 500,
|
||||
progress: 0.5,
|
||||
status: 'running',
|
||||
...over
|
||||
}
|
||||
}
|
||||
|
||||
const taskId = (id: string) => id as TaskId
|
||||
|
||||
/**
|
||||
* Build a store and an `emit` bound to the real `asset_export` listener the
|
||||
* store registers on `api`, so tests drive the state machine through its
|
||||
* actual entry point rather than a private method.
|
||||
*/
|
||||
function setup() {
|
||||
const store = useAssetExportStore()
|
||||
const entry = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find((c) => c[0] === 'asset_export')
|
||||
const handler = entry![1] as (e: { detail: AssetExportWsMessage }) => void
|
||||
const emit = (msg: AssetExportWsMessage) => handler({ detail: msg })
|
||||
// Run the polling tick that `useIntervalFn` would normally fire, and let its
|
||||
// async work settle.
|
||||
const runPoll = async () => {
|
||||
intervalState.cb?.()
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
}
|
||||
return { store, emit, runPoll }
|
||||
}
|
||||
|
||||
const STALE_AGO_MS = 20_000
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
vi.mocked(api.addEventListener).mockClear()
|
||||
getExportDownloadUrl
|
||||
.mockReset()
|
||||
.mockResolvedValue({ url: 'https://example.com/export.zip' })
|
||||
getTask.mockReset()
|
||||
toastAdd.mockReset()
|
||||
})
|
||||
|
||||
describe('assetExportStore', () => {
|
||||
it('tracks a new export as created and is idempotent', () => {
|
||||
const { store } = setup()
|
||||
|
||||
store.trackExport(taskId('t1'))
|
||||
store.trackExport(taskId('t1'))
|
||||
|
||||
expect(store.exportList).toHaveLength(1)
|
||||
expect(store.exportList[0].status).toBe('created')
|
||||
expect(store.hasExports).toBe(true)
|
||||
expect(store.hasActiveExports).toBe(true)
|
||||
})
|
||||
|
||||
it('separates active from finished exports by status', () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ task_id: 'running', status: 'running' }))
|
||||
emit(
|
||||
wsMessage({ task_id: 'failed', status: 'failed', export_name: 'f.zip' })
|
||||
)
|
||||
|
||||
expect(store.activeExports.map((e) => e.taskId)).toEqual(['running'])
|
||||
expect(store.finishedExports.map((e) => e.taskId)).toEqual(['failed'])
|
||||
})
|
||||
|
||||
it('updates an export from successive websocket messages', () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ progress: 0.5, status: 'running' }))
|
||||
emit(wsMessage({ progress: 0.9, status: 'running' }))
|
||||
|
||||
expect(store.exportList).toHaveLength(1)
|
||||
expect(store.exportList[0].progress).toBe(0.9)
|
||||
})
|
||||
|
||||
it('ignores updates for an export already completed and downloaded', async () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ status: 'completed' }))
|
||||
await Promise.resolve()
|
||||
const triggeredCalls = getExportDownloadUrl.mock.calls.length
|
||||
|
||||
// A late 'running' message must not revive a completed+downloaded export
|
||||
emit(wsMessage({ status: 'running', progress: 0.1 }))
|
||||
|
||||
expect(store.exportList[0].status).toBe('completed')
|
||||
expect(getExportDownloadUrl).toHaveBeenCalledTimes(triggeredCalls)
|
||||
})
|
||||
|
||||
it('falls back to the prior export name when a message omits it', async () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ status: 'running', progress: 0.4 }))
|
||||
emit(
|
||||
wsMessage({ status: 'running', export_name: undefined, progress: 0.6 })
|
||||
)
|
||||
|
||||
expect(store.exportList[0].exportName).toBe('export.zip')
|
||||
})
|
||||
|
||||
it('falls back to a blank export name when no message has named it', () => {
|
||||
const { store, emit } = setup()
|
||||
|
||||
emit(wsMessage({ export_name: undefined, status: 'running' }))
|
||||
|
||||
expect(store.exportList[0].exportName).toBe('')
|
||||
})
|
||||
|
||||
it('triggers a download for a named export and clears prior errors', async () => {
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
|
||||
expect(getExportDownloadUrl).toHaveBeenCalledWith('export.zip')
|
||||
expect(exp.downloadTriggered).toBe(true)
|
||||
expect(exp.downloadError).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not re-trigger a download unless forced', async () => {
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
exp.downloadTriggered = true
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
expect(getExportDownloadUrl).not.toHaveBeenCalled()
|
||||
|
||||
await store.triggerDownload(exp, true)
|
||||
expect(getExportDownloadUrl).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('records a download error and surfaces a toast on failure', async () => {
|
||||
getExportDownloadUrl.mockRejectedValueOnce(new Error('network down'))
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
|
||||
expect(exp.downloadError).toBe('network down')
|
||||
expect(exp.downloadTriggered).toBe(false)
|
||||
expect(toastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
|
||||
it('records a string download error', async () => {
|
||||
getExportDownloadUrl.mockRejectedValueOnce('offline')
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
const [exp] = store.exportList
|
||||
|
||||
await store.triggerDownload(exp)
|
||||
|
||||
expect(exp.downloadError).toBe('offline')
|
||||
})
|
||||
|
||||
it('clears finished exports while keeping active ones', () => {
|
||||
const { store, emit } = setup()
|
||||
emit(wsMessage({ task_id: 'a', status: 'running' }))
|
||||
emit(wsMessage({ task_id: 'b', status: 'failed', export_name: 'b.zip' }))
|
||||
|
||||
store.clearFinishedExports()
|
||||
|
||||
expect(store.exportList.map((e) => e.taskId)).toEqual(['a'])
|
||||
})
|
||||
|
||||
it('does not poll when no active export is stale', async () => {
|
||||
const { emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(getTask).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('reconciles a stale export from the task service result', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockResolvedValue({
|
||||
status: 'completed',
|
||||
result: { export_name: 'reconciled.zip', assets_total: 10 }
|
||||
})
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(getTask).toHaveBeenCalledWith('task-1')
|
||||
expect(store.exportList[0].status).toBe('completed')
|
||||
expect(store.exportList[0].exportName).toBe('reconciled.zip')
|
||||
})
|
||||
|
||||
it('leaves a stale export active when the task is still running', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockResolvedValue({ status: 'running' })
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(store.exportList[0].status).toBe('running')
|
||||
})
|
||||
|
||||
it('reconciles a stale failed export using existing counters', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(
|
||||
wsMessage({
|
||||
assets_attempted: 4,
|
||||
assets_failed: 1,
|
||||
status: 'running'
|
||||
})
|
||||
)
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockResolvedValue({
|
||||
status: 'failed',
|
||||
result: { error: 'failed in result' }
|
||||
})
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(store.exportList[0]).toMatchObject({
|
||||
assetsAttempted: 4,
|
||||
assetsFailed: 1,
|
||||
error: 'failed in result',
|
||||
status: 'failed'
|
||||
})
|
||||
})
|
||||
|
||||
it('leaves a stale export untouched when the task lookup fails', async () => {
|
||||
const { store, emit, runPoll } = setup()
|
||||
emit(wsMessage({ status: 'running' }))
|
||||
store.exportList[0].lastUpdate = Date.now() - STALE_AGO_MS
|
||||
getTask.mockRejectedValue(new Error('task not found'))
|
||||
|
||||
await runPoll()
|
||||
|
||||
expect(store.exportList[0].status).toBe('running')
|
||||
})
|
||||
})
|
||||
@@ -2,8 +2,10 @@ import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { MissingNodeType } from '@/types/comfy'
|
||||
import { createNodeExecutionId } from '@/types/nodeIdentification'
|
||||
import type { NodeLocatorId } from '@/types/nodeIdentification'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@/i18n', () => ({
|
||||
@@ -15,6 +17,53 @@ vi.mock('@/platform/distribution/types', () => ({
|
||||
}))
|
||||
|
||||
const mockShowErrorsTab = vi.hoisted(() => ({ value: false }))
|
||||
const {
|
||||
mockApp,
|
||||
mockCanvasStore,
|
||||
mockExecutionIdToNodeLocatorId,
|
||||
mockGetExecutionIdByNode,
|
||||
mockGetNodeByExecutionId,
|
||||
mockWorkflowStore
|
||||
} = vi.hoisted(() => ({
|
||||
mockApp: {
|
||||
isGraphReady: true,
|
||||
rootGraph: {}
|
||||
},
|
||||
mockCanvasStore: {
|
||||
currentGraph: undefined as object | undefined
|
||||
},
|
||||
mockExecutionIdToNodeLocatorId: vi.fn(
|
||||
(_rootGraph: unknown, id: string) => id as NodeLocatorId
|
||||
),
|
||||
mockGetExecutionIdByNode: vi.fn(),
|
||||
mockGetNodeByExecutionId: vi.fn(),
|
||||
mockWorkflowStore: {
|
||||
nodeLocatorIdToNodeId: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({ app: mockApp }))
|
||||
|
||||
vi.mock('@/renderer/core/canvas/canvasStore', () => ({
|
||||
useCanvasStore: () => mockCanvasStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: () => mockWorkflowStore
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/graphTraversalUtil', () => ({
|
||||
executionIdToNodeLocatorId: (
|
||||
...args: Parameters<typeof mockExecutionIdToNodeLocatorId>
|
||||
) => mockExecutionIdToNodeLocatorId(...args),
|
||||
forEachNode: vi.fn(),
|
||||
getExecutionIdByNode: (
|
||||
...args: Parameters<typeof mockGetExecutionIdByNode>
|
||||
) => mockGetExecutionIdByNode(...args),
|
||||
getNodeByExecutionId: (
|
||||
...args: Parameters<typeof mockGetNodeByExecutionId>
|
||||
) => mockGetNodeByExecutionId(...args)
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/settingStore', () => ({
|
||||
useSettingStore: vi.fn(() => ({
|
||||
@@ -39,6 +88,21 @@ import { useExecutionErrorStore } from './executionErrorStore'
|
||||
import { useMissingNodesErrorStore } from '@/platform/nodeReplacement/missingNodesErrorStore'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
beforeEach(() => {
|
||||
mockShowErrorsTab.value = false
|
||||
mockApp.isGraphReady = true
|
||||
mockCanvasStore.currentGraph = undefined
|
||||
mockExecutionIdToNodeLocatorId.mockImplementation(
|
||||
(_rootGraph: unknown, id: string) => id as NodeLocatorId
|
||||
)
|
||||
mockGetExecutionIdByNode.mockReset()
|
||||
mockGetNodeByExecutionId.mockReset()
|
||||
mockWorkflowStore.nodeLocatorIdToNodeId.mockImplementation(
|
||||
(locator: NodeLocatorId) =>
|
||||
toNodeId(String(locator).split(':').at(-1) ?? locator)
|
||||
)
|
||||
})
|
||||
|
||||
describe('executionErrorStore — node error operations', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
@@ -144,6 +208,31 @@ describe('executionErrorStore — node error operations', () => {
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('does nothing when the requested slot has no errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
'123': {
|
||||
errors: [
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: 'Max exceeded',
|
||||
details: '',
|
||||
extra_info: { input_name: 'otherSlot' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'TestNode'
|
||||
}
|
||||
}
|
||||
|
||||
store.clearSimpleNodeErrors(
|
||||
createNodeExecutionId([toNodeId(123)]),
|
||||
'testSlot'
|
||||
)
|
||||
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('preserves complex errors when slot has both simple and complex errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
@@ -388,6 +477,359 @@ describe('executionErrorStore — node error operations', () => {
|
||||
expect(store.lastNodeErrors).not.toBeNull()
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('keeps numeric range errors when no range options prove them valid', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
'123': {
|
||||
errors: [
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: '...',
|
||||
details: '',
|
||||
extra_info: { input_name: 'testWidget' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'TestNode'
|
||||
}
|
||||
}
|
||||
|
||||
store.clearWidgetRelatedErrors(
|
||||
createNodeExecutionId([toNodeId(123)]),
|
||||
'testWidget',
|
||||
'testWidget',
|
||||
15
|
||||
)
|
||||
|
||||
expect(store.lastNodeErrors?.['123'].errors).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('is a no-op when the target execution id has no node error entry', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastNodeErrors = {
|
||||
'999': {
|
||||
errors: [
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: '...',
|
||||
details: '',
|
||||
extra_info: { input_name: 'testWidget' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'TestNode'
|
||||
}
|
||||
}
|
||||
|
||||
store.clearWidgetRelatedErrors(
|
||||
createNodeExecutionId([toNodeId(123)]),
|
||||
'testWidget',
|
||||
'testWidget',
|
||||
15,
|
||||
{ max: 10 }
|
||||
)
|
||||
|
||||
expect(store.lastNodeErrors?.['123']).toBeUndefined()
|
||||
expect(store.lastNodeErrors?.['999'].errors).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('startup clearing', () => {
|
||||
it('clears execution-start errors and closes the overlay when node errors are empty', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '1' })
|
||||
store.lastPromptError = fromAny({ message: 'prompt failed' })
|
||||
store.lastNodeErrors = {}
|
||||
store.showErrorOverlay()
|
||||
|
||||
store.clearExecutionStartErrors()
|
||||
|
||||
expect(store.lastExecutionError).toBeNull()
|
||||
expect(store.lastPromptError).toBeNull()
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('keeps the overlay open when node errors remain after execution start', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '1' })
|
||||
store.lastPromptError = fromAny({ message: 'prompt failed' })
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
store.showErrorOverlay()
|
||||
|
||||
store.clearExecutionStartErrors()
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('executionErrorStore derived graph state', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
})
|
||||
|
||||
it('derives execution error node ids through locator mapping', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockExecutionIdToNodeLocatorId.mockReturnValue(
|
||||
fromAny<NodeLocatorId, string>('graph:7')
|
||||
)
|
||||
store.lastExecutionError = fromAny({ node_id: '7' })
|
||||
|
||||
expect(store.lastExecutionErrorNodeId).toBe(toNodeId(7))
|
||||
})
|
||||
|
||||
it('returns null when there is no execution error locator', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '7' })
|
||||
mockExecutionIdToNodeLocatorId.mockReturnValue(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
expect(store.lastExecutionErrorNodeId).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null when there is no execution error', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
|
||||
expect(store.lastExecutionErrorNodeId).toBeNull()
|
||||
})
|
||||
|
||||
it('combines prompt, node, execution, and missing-node error counts', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const missingNodesStore = useMissingNodesErrorStore()
|
||||
store.lastPromptError = fromAny({ message: 'prompt failed' })
|
||||
store.lastExecutionError = fromAny({ node_id: null })
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
},
|
||||
{
|
||||
type: 'value_bigger_than_max',
|
||||
message: 'Too large',
|
||||
details: '',
|
||||
extra_info: { input_name: 'y' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
missingNodesStore.setMissingNodeTypes(
|
||||
fromAny<MissingNodeType[], unknown>([{ type: 'MissingNode', hint: '' }])
|
||||
)
|
||||
|
||||
expect(store.hasPromptError).toBe(true)
|
||||
expect(store.hasNodeError).toBe(true)
|
||||
expect(store.hasExecutionError).toBe(true)
|
||||
expect(store.hasAnyError).toBe(true)
|
||||
expect(store.allErrorExecutionIds).toEqual(['1'])
|
||||
expect(store.totalErrorCount).toBe(5)
|
||||
})
|
||||
|
||||
it('reports empty derived state when there are no errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
|
||||
expect(store.hasNodeError).toBe(false)
|
||||
expect(store.allErrorExecutionIds).toEqual([])
|
||||
expect(store.totalErrorCount).toBe(0)
|
||||
})
|
||||
|
||||
it('includes defined execution node ids in the error id list', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: '2' })
|
||||
|
||||
expect(store.allErrorExecutionIds).toEqual(['2'])
|
||||
})
|
||||
|
||||
it('excludes undefined execution node ids from the error id list', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
store.lastExecutionError = fromAny({ node_id: undefined })
|
||||
|
||||
expect(store.allErrorExecutionIds).toEqual([])
|
||||
})
|
||||
|
||||
it('collects active graph node ids for validation and execution errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const activeGraph = {}
|
||||
mockCanvasStore.currentGraph = activeGraph
|
||||
mockGetNodeByExecutionId.mockImplementation((_rootGraph, id: string) => ({
|
||||
id: toNodeId(id),
|
||||
graph: activeGraph
|
||||
}))
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
store.lastExecutionError = fromAny({ node_id: '2' })
|
||||
|
||||
expect([...store.activeGraphErrorNodeIds].sort()).toEqual(['1', '2'])
|
||||
})
|
||||
|
||||
it('falls back to the root graph when there is no current canvas graph', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockCanvasStore.currentGraph = undefined
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
id: toNodeId(1),
|
||||
graph: mockApp.rootGraph
|
||||
})
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
|
||||
expect([...store.activeGraphErrorNodeIds]).toEqual(['1'])
|
||||
})
|
||||
|
||||
it('ignores graph errors outside the active graph', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const activeGraph = {}
|
||||
mockCanvasStore.currentGraph = activeGraph
|
||||
mockGetNodeByExecutionId.mockReturnValue({
|
||||
id: toNodeId(1),
|
||||
graph: {}
|
||||
})
|
||||
store.lastNodeErrors = {
|
||||
'1': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
store.lastExecutionError = fromAny({ node_id: '1' })
|
||||
|
||||
expect(store.activeGraphErrorNodeIds.size).toBe(0)
|
||||
})
|
||||
|
||||
it('returns no active graph node ids before the graph is ready', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockApp.isGraphReady = false
|
||||
store.lastExecutionError = fromAny({ node_id: '2' })
|
||||
|
||||
expect(store.activeGraphErrorNodeIds.size).toBe(0)
|
||||
})
|
||||
|
||||
it('maps node errors by locator and checks slots', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const nodeError = {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
mockExecutionIdToNodeLocatorId.mockImplementation((_rootGraph, id) =>
|
||||
id === 'missing'
|
||||
? fromAny<NodeLocatorId, undefined>(undefined)
|
||||
: fromAny<NodeLocatorId, string>(`locator:${id}`)
|
||||
)
|
||||
store.lastNodeErrors = {
|
||||
'1': nodeError,
|
||||
missing: nodeError
|
||||
}
|
||||
|
||||
const locator = fromAny<NodeLocatorId, string>('locator:1')
|
||||
expect(store.getNodeErrors(locator)).toEqual(nodeError)
|
||||
expect(store.slotHasError(locator, 'x')).toBe(true)
|
||||
expect(store.slotHasError(locator, 'y')).toBe(false)
|
||||
expect(
|
||||
store.getNodeErrors(fromAny<NodeLocatorId, string>('locator:missing'))
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns no slot error when there are no node errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
|
||||
expect(
|
||||
store.slotHasError(fromAny<NodeLocatorId, string>('locator:1'), 'x')
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('detects container nodes with internal errors', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
const node = fromAny<LGraphNode, unknown>({})
|
||||
mockGetExecutionIdByNode.mockReturnValueOnce(undefined)
|
||||
|
||||
expect(store.isContainerWithInternalError(node)).toBe(false)
|
||||
|
||||
store.lastNodeErrors = {
|
||||
'1:2': {
|
||||
errors: [
|
||||
{
|
||||
type: 'required_input_missing',
|
||||
message: 'Missing',
|
||||
details: '',
|
||||
extra_info: { input_name: 'x' }
|
||||
}
|
||||
],
|
||||
dependent_outputs: [],
|
||||
class_type: 'Test'
|
||||
}
|
||||
}
|
||||
mockGetExecutionIdByNode.mockReturnValue(
|
||||
createNodeExecutionId([toNodeId(1)])
|
||||
)
|
||||
|
||||
expect(store.isContainerWithInternalError(node)).toBe(true)
|
||||
})
|
||||
|
||||
it('does not report container errors before the graph is ready', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockApp.isGraphReady = false
|
||||
|
||||
expect(
|
||||
store.isContainerWithInternalError(fromAny<LGraphNode, unknown>({}))
|
||||
).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -457,6 +899,23 @@ describe('surfaceMissingModels — silent option', () => {
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('does NOT open error overlay when the setting is disabled', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockShowErrorsTab.value = false
|
||||
store.surfaceMissingModels([
|
||||
fromAny({
|
||||
name: 'model.safetensors',
|
||||
nodeId: toNodeId('1'),
|
||||
nodeType: 'Loader',
|
||||
widgetName: 'ckpt',
|
||||
isMissing: true,
|
||||
isAssetSupported: false
|
||||
})
|
||||
])
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('surfaceMissingMedia — silent option', () => {
|
||||
@@ -525,6 +984,23 @@ describe('surfaceMissingMedia — silent option', () => {
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
|
||||
it('does NOT open error overlay when the setting is disabled', () => {
|
||||
const store = useExecutionErrorStore()
|
||||
mockShowErrorsTab.value = false
|
||||
store.surfaceMissingMedia([
|
||||
fromAny({
|
||||
name: 'photo.png',
|
||||
nodeId: toNodeId('1'),
|
||||
nodeType: 'LoadImage',
|
||||
widgetName: 'image',
|
||||
mediaType: 'image',
|
||||
isMissing: true
|
||||
})
|
||||
])
|
||||
|
||||
expect(store.isErrorOverlayOpen).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('clearAllErrors', () => {
|
||||
|
||||
@@ -3,6 +3,7 @@ import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick, ref } from 'vue'
|
||||
import { createMemoryHistory, createRouter } from 'vue-router'
|
||||
|
||||
import type * as VueRouter from 'vue-router'
|
||||
|
||||
@@ -102,12 +103,24 @@ vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
function makeSubgraph(id: string): Subgraph {
|
||||
return fromPartial<Subgraph>({
|
||||
id,
|
||||
isRootGraph: false,
|
||||
rootGraph: app.rootGraph,
|
||||
_nodes: [],
|
||||
nodes: []
|
||||
})
|
||||
}
|
||||
|
||||
async function makeDuplicatedNavigationFailure(): Promise<Error> {
|
||||
const router = createRouter({
|
||||
history: createMemoryHistory(),
|
||||
routes: [{ path: '/', component: {} }]
|
||||
})
|
||||
await router.push('/')
|
||||
const failure = await router.push('/')
|
||||
if (!failure) throw new Error('Expected duplicated navigation failure')
|
||||
return failure
|
||||
}
|
||||
|
||||
async function flushHashWatcher() {
|
||||
await nextTick()
|
||||
await Promise.resolve()
|
||||
@@ -118,6 +131,7 @@ describe('useSubgraphNavigationStore - navigateToHash validation', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(app.canvas.setGraph).mockReset()
|
||||
app.rootGraph.id = ids.root
|
||||
app.rootGraph.subgraphs.clear()
|
||||
app.canvas.subgraph = undefined
|
||||
@@ -230,6 +244,44 @@ describe('useSubgraphNavigationStore - navigateToHash validation', () => {
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('does not warn when recovery redirect hits a duplicated navigation', async () => {
|
||||
routerMocks.replace.mockRejectedValueOnce(
|
||||
await makeDuplicatedNavigationFailure()
|
||||
)
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
app.canvas.graph = makeSubgraph(ids.deletedSubgraph)
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
|
||||
expect(warnSpy).not.toHaveBeenCalledWith(
|
||||
'[subgraphNavigation] router.replace rejected during recovery',
|
||||
expect.any(Error)
|
||||
)
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('recovers to root when canvas is unavailable during redirect cleanup', async () => {
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
const canvas = appWithOptionalCanvas.canvas
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
try {
|
||||
routeHashRef.value = '#not-a-valid-uuid'
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('redirects when a workflow load resolves but the subgraph is still missing', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
workflowStoreState.openWorkflows = [
|
||||
@@ -304,4 +356,196 @@ describe('useSubgraphNavigationStore - navigateToHash validation', () => {
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('updateHash does nothing on initial load with an empty hash', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.replace).not.toHaveBeenCalled()
|
||||
expect(routerMocks.push).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('updateHash follows a non-empty initial subgraph hash', async () => {
|
||||
const subgraph = makeSubgraph(ids.validSubgraph)
|
||||
app.rootGraph.subgraphs.set(subgraph.id, subgraph)
|
||||
vi.mocked(app.canvas.setGraph).mockImplementation((graph) => {
|
||||
app.canvas.graph = graph
|
||||
})
|
||||
routeHashRef.value = `#${ids.validSubgraph}`
|
||||
const store = useSubgraphNavigationStore()
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(subgraph)
|
||||
})
|
||||
|
||||
it('updateHash does not treat the initial root hash as a subgraph', async () => {
|
||||
routeHashRef.value = `#${ids.root}`
|
||||
app.canvas.graph = app.rootGraph
|
||||
const store = useSubgraphNavigationStore()
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(workflowStoreState.activeSubgraph).toBeUndefined()
|
||||
})
|
||||
|
||||
it('updateHash replaces an empty hash and pushes the active graph id', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
expect(routerMocks.push).toHaveBeenCalledWith(`#${ids.validSubgraph}`)
|
||||
})
|
||||
|
||||
it('updateHash skips router push when hash already matches the active graph', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = `#${ids.validSubgraph}`
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.push).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('updateHash skips router push when the active graph has no id', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = '#old'
|
||||
app.canvas.graph = fromPartial<LGraph>({})
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(routerMocks.push).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('updateHash warns when router push rejects', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
routerMocks.push.mockRejectedValueOnce(new Error('push failed'))
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = '#old'
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[subgraphNavigation] router.push rejected',
|
||||
expect.any(Error)
|
||||
)
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('updateHash ignores duplicated router push failures', async () => {
|
||||
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
routerMocks.push.mockRejectedValueOnce(
|
||||
await makeDuplicatedNavigationFailure()
|
||||
)
|
||||
const store = useSubgraphNavigationStore()
|
||||
await store.updateHash()
|
||||
routeHashRef.value = `#${ids.root}`
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.validSubgraph })
|
||||
|
||||
await store.updateHash()
|
||||
|
||||
expect(warnSpy).not.toHaveBeenCalled()
|
||||
warnSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('skips workflows without active state during hash recovery', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({ path: 'inactive.json' })
|
||||
]
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
})
|
||||
|
||||
it('skips workflow states and subgraphs that do not match the hash', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'other-workflow.json',
|
||||
activeState: {
|
||||
id: ids.validSubgraph,
|
||||
definitions: {
|
||||
subgraphs: [{ id: ids.validSubgraph }]
|
||||
}
|
||||
}
|
||||
})
|
||||
]
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
})
|
||||
|
||||
it('handles workflow states with no subgraph definitions during recovery', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'no-definitions.json',
|
||||
activeState: { id: ids.validSubgraph }
|
||||
})
|
||||
]
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(routerMocks.replace).toHaveBeenCalledWith(`#${app.rootGraph.id}`)
|
||||
)
|
||||
})
|
||||
|
||||
it('opens a workflow and navigates to the loaded root graph', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'root-workflow.json',
|
||||
activeState: {
|
||||
id: ids.deletedSubgraph,
|
||||
definitions: { subgraphs: [] }
|
||||
}
|
||||
})
|
||||
]
|
||||
workflowServiceMocks.openWorkflow.mockImplementation(async () => {
|
||||
app.rootGraph.id = ids.deletedSubgraph
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.root })
|
||||
})
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(app.canvas.setGraph).toHaveBeenCalledWith(app.rootGraph)
|
||||
)
|
||||
})
|
||||
|
||||
it('does not reset the graph when loaded workflow is already active', async () => {
|
||||
workflowStoreState.openWorkflows = [
|
||||
fromPartial<ComfyWorkflow>({
|
||||
path: 'already-active.json',
|
||||
activeState: {
|
||||
id: ids.deletedSubgraph,
|
||||
definitions: { subgraphs: [] }
|
||||
}
|
||||
})
|
||||
]
|
||||
workflowServiceMocks.openWorkflow.mockImplementation(async () => {
|
||||
app.rootGraph.id = ids.deletedSubgraph
|
||||
app.canvas.graph = fromPartial<LGraph>({ id: ids.deletedSubgraph })
|
||||
})
|
||||
useSubgraphNavigationStore()
|
||||
|
||||
routeHashRef.value = `#${ids.deletedSubgraph}`
|
||||
await vi.waitFor(() =>
|
||||
expect(workflowServiceMocks.openWorkflow).toHaveBeenCalled()
|
||||
)
|
||||
|
||||
expect(app.canvas.setGraph).not.toHaveBeenCalledWith(app.rootGraph)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
@@ -136,6 +137,23 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
})
|
||||
|
||||
describe('saveViewport', () => {
|
||||
it('does not save when canvas is unavailable', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const canvas = app.canvas
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
|
||||
try {
|
||||
store.saveViewport('root')
|
||||
|
||||
expect(store.viewportCache.has(':root')).toBe(false)
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('saves viewport state for root graph', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
mockCanvas.ds.state.scale = 2
|
||||
@@ -164,6 +182,42 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
})
|
||||
|
||||
describe('restoreViewport', () => {
|
||||
it('does nothing when canvas is unavailable', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const canvas = app.canvas
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
|
||||
try {
|
||||
store.restoreViewport('root')
|
||||
|
||||
expect(mockSetDirty).not.toHaveBeenCalled()
|
||||
expect(rafCallbacks).toHaveLength(0)
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('does not apply cached viewport when canvas disappears', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const canvas = app.canvas
|
||||
const appWithOptionalCanvas = app as unknown as {
|
||||
canvas: typeof app.canvas | undefined
|
||||
}
|
||||
store.viewportCache.set(':root', { scale: 2.5, offset: [150, 250] })
|
||||
appWithOptionalCanvas.canvas = undefined
|
||||
|
||||
try {
|
||||
store.restoreViewport('root')
|
||||
|
||||
expect(mockSetDirty).not.toHaveBeenCalled()
|
||||
} finally {
|
||||
appWithOptionalCanvas.canvas = canvas
|
||||
}
|
||||
})
|
||||
|
||||
it('restores cached viewport', () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
store.viewportCache.set(':root', { scale: 2.5, offset: [150, 250] })
|
||||
@@ -266,7 +320,10 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
expect(mockFitView).toHaveBeenCalledOnce()
|
||||
|
||||
// User navigated away before the inner RAF fired
|
||||
mockCanvas.subgraph = { id: 'different-graph' } as never
|
||||
mockCanvas.subgraph = fromPartial<Subgraph>({
|
||||
id: 'different-graph',
|
||||
isRootGraph: false
|
||||
})
|
||||
rafCallbacks[1](performance.now())
|
||||
|
||||
expect(mockRequestSlotSyncAll).not.toHaveBeenCalled()
|
||||
@@ -283,7 +340,10 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
expect(rafCallbacks).toHaveLength(1)
|
||||
|
||||
// Simulate graph switching away before rAF fires
|
||||
mockCanvas.subgraph = { id: 'different-graph' } as never
|
||||
mockCanvas.subgraph = fromPartial<Subgraph>({
|
||||
id: 'different-graph',
|
||||
isRootGraph: false
|
||||
})
|
||||
|
||||
rafCallbacks[0](performance.now())
|
||||
|
||||
@@ -341,6 +401,23 @@ describe('useSubgraphNavigationStore - Viewport Persistence', () => {
|
||||
expect(mockCanvas.ds.offset).toEqual([100, 100])
|
||||
})
|
||||
|
||||
it('does not save the outgoing viewport while a workflow switch is blocked', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const workflowStore = useWorkflowStore()
|
||||
const subgraph = fromPartial<Subgraph>({
|
||||
id: 'sub1',
|
||||
isRootGraph: false,
|
||||
rootGraph: app.rootGraph
|
||||
})
|
||||
|
||||
store.saveCurrentViewport()
|
||||
store.viewportCache.clear()
|
||||
workflowStore.activeSubgraph = subgraph
|
||||
await nextTick()
|
||||
|
||||
expect(store.viewportCache.has(':root')).toBe(false)
|
||||
})
|
||||
|
||||
it('preserves pre-existing cache entries across workflow switches', async () => {
|
||||
const store = useSubgraphNavigationStore()
|
||||
const workflowStore = useWorkflowStore()
|
||||
|
||||
@@ -10,9 +10,14 @@ import {
|
||||
import type { ExportedSubgraph } from '@/lib/litegraph/src/types/serialisation'
|
||||
import { TemplateIncludeOnDistributionEnum } from '@/platform/workflow/templates/types/template'
|
||||
import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema'
|
||||
import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import type { GlobalSubgraphData } from '@/scripts/api'
|
||||
import { api } from '@/scripts/api'
|
||||
import { app as comfyApp } from '@/scripts/app'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useNodeDefStore } from '@/stores/nodeDefStore'
|
||||
import { useSubgraphStore } from '@/stores/subgraphStore'
|
||||
@@ -36,6 +41,7 @@ vi.mock('@/scripts/api', () => ({
|
||||
storeUserData: vi.fn(),
|
||||
listUserDataFullInfo: vi.fn(),
|
||||
getGlobalSubgraphs: vi.fn(),
|
||||
deleteUserData: vi.fn(),
|
||||
apiURL: vi.fn(),
|
||||
addEventListener: vi.fn()
|
||||
}
|
||||
@@ -98,6 +104,12 @@ describe('useSubgraphStore', () => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
store = useSubgraphStore()
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm: vi.fn(() => true)
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should allow publishing of a subgraph', async () => {
|
||||
@@ -134,6 +146,86 @@ describe('useSubgraphStore', () => {
|
||||
await store.publishSubgraph()
|
||||
expect(api.storeUserData).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('rejects publishing when a single subgraph node is not selected', async () => {
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set()
|
||||
|
||||
await expect(store.publishSubgraph()).rejects.toThrow(
|
||||
'Must have single SubgraphNode selected to publish'
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects publishing when serialization produces multiple nodes', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [subgraphNode.serialize(), subgraphNode.serialize()],
|
||||
subgraphs: []
|
||||
}))
|
||||
|
||||
await expect(store.publishSubgraph()).rejects.toThrow(
|
||||
'Must have single SubgraphNode selected to publish'
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects publishing when the serialized node is not a subgraph node', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas).draw = vi.fn()
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [{ ...subgraphNode.serialize(), type: 'missing' }],
|
||||
subgraphs: [fromAny<ExportedSubgraph, unknown>(subgraph.serialize())]
|
||||
}))
|
||||
|
||||
await expect(store.publishSubgraph('invalid')).rejects.toThrow(
|
||||
'Loaded subgraph blueprint does not contain valid subgraph'
|
||||
)
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not publish when the name prompt is cancelled', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [subgraphNode.serialize()],
|
||||
subgraphs: [fromAny<ExportedSubgraph, unknown>(subgraph.serialize())]
|
||||
}))
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => null),
|
||||
confirm: vi.fn(() => true)
|
||||
})
|
||||
)
|
||||
|
||||
await store.publishSubgraph()
|
||||
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not overwrite an existing blueprint when confirmation is cancelled', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => ({
|
||||
nodes: [subgraphNode.serialize()],
|
||||
subgraphs: [fromAny<ExportedSubgraph, unknown>(subgraph.serialize())]
|
||||
}))
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'test'),
|
||||
confirm: vi.fn(() => false)
|
||||
})
|
||||
)
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
|
||||
await store.publishSubgraph('test')
|
||||
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should display published nodes in the node library', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
expect(
|
||||
@@ -148,6 +240,30 @@ describe('useSubgraphStore', () => {
|
||||
//check active graph
|
||||
expect(comfyApp.loadGraphData).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('switches into the nested subgraph when editing opens a wrapper graph', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const setGraph = vi.fn()
|
||||
const nested = { id: 'nested' }
|
||||
vi.mocked(comfyApp.canvas).graph = fromAny<
|
||||
NonNullable<typeof comfyApp.canvas.graph>,
|
||||
unknown
|
||||
>({
|
||||
nodes: [{ subgraph: nested }],
|
||||
setGraph
|
||||
})
|
||||
vi.mocked(comfyApp.canvas).setGraph = setGraph
|
||||
|
||||
await store.editBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')
|
||||
|
||||
expect(setGraph).toHaveBeenCalledWith(nested)
|
||||
})
|
||||
|
||||
it('throws when editing an unloaded blueprint', async () => {
|
||||
await expect(
|
||||
store.editBlueprint(BLUEPRINT_TYPE_PREFIX + 'missing')
|
||||
).rejects.toThrow('not yet loaded')
|
||||
})
|
||||
it('should allow subgraphs to be added to graph', async () => {
|
||||
//mock
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
@@ -166,6 +282,12 @@ describe('useSubgraphStore', () => {
|
||||
expect(second.nodes[0].id).not.toBe(-1)
|
||||
expect(second.definitions!.subgraphs![0].id).toBe('123')
|
||||
})
|
||||
|
||||
it('throws when getting an unloaded blueprint', () => {
|
||||
expect(() => store.getBlueprint(BLUEPRINT_TYPE_PREFIX + 'missing')).toThrow(
|
||||
'not yet loaded'
|
||||
)
|
||||
})
|
||||
it('should identify user blueprints as non-global', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
expect(store.isGlobalBlueprint('test')).toBe(false)
|
||||
@@ -188,6 +310,57 @@ describe('useSubgraphStore', () => {
|
||||
expect(store.isGlobalBlueprint('nonexistent')).toBe(false)
|
||||
})
|
||||
|
||||
describe('deleteBlueprint', () => {
|
||||
it('throws for unloaded blueprints', async () => {
|
||||
await expect(
|
||||
store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'missing')
|
||||
).rejects.toThrow('not yet loaded')
|
||||
})
|
||||
|
||||
it('does not delete global blueprints', async () => {
|
||||
await mockFetch(
|
||||
{},
|
||||
{
|
||||
global_bp: {
|
||||
name: 'Global Blueprint',
|
||||
info: { node_pack: 'comfy_essentials' },
|
||||
data: JSON.stringify(mockGraph)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
await store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'global_bp')
|
||||
|
||||
expect(api.deleteUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not delete when confirmation is cancelled', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm: vi.fn(() => false)
|
||||
})
|
||||
)
|
||||
|
||||
await store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')
|
||||
|
||||
expect(api.deleteUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('deletes user blueprints after confirmation', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
vi.mocked(api.deleteUserData).mockResolvedValue({
|
||||
status: 204
|
||||
} as Response)
|
||||
|
||||
await store.deleteBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')
|
||||
|
||||
expect(api.deleteUserData).toHaveBeenCalledWith('subgraphs/test.json')
|
||||
expect(store.isUserBlueprint(BLUEPRINT_TYPE_PREFIX + 'test')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isUserBlueprint', () => {
|
||||
it('should return true for user blueprints', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
@@ -285,6 +458,206 @@ describe('useSubgraphStore', () => {
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('continues when global blueprint discovery rejects', async () => {
|
||||
vi.mocked(api.listUserDataFullInfo).mockResolvedValue([])
|
||||
vi.mocked(api.getGlobalSubgraphs).mockRejectedValue(
|
||||
new Error('global down')
|
||||
)
|
||||
|
||||
await store.fetchSubgraphs()
|
||||
|
||||
expect(store.subgraphBlueprints).toEqual([])
|
||||
})
|
||||
|
||||
it('reports compact detail when more than three blueprints fail', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
await mockFetch(
|
||||
{},
|
||||
{
|
||||
a: { name: 'A', info: { node_pack: 'test' }, data: '' },
|
||||
b: { name: 'B', info: { node_pack: 'test' }, data: '' },
|
||||
c: { name: 'C', info: { node_pack: 'test' }, data: '' },
|
||||
d: { name: 'D', info: { node_pack: 'test' }, data: '' }
|
||||
}
|
||||
)
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledTimes(4)
|
||||
expect(useToastStore().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: 'x4' })
|
||||
)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('ignores invalid user blueprint files during fetch', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await mockFetch({
|
||||
'invalid.json': {
|
||||
nodes: [],
|
||||
definitions: { subgraphs: [] }
|
||||
}
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Failed to load subgraph blueprint',
|
||||
expect.any(Error)
|
||||
)
|
||||
expect(store.subgraphBlueprints).toHaveLength(0)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('rejects loaded blueprints whose wrapper node does not reference a subgraph', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await mockFetch({
|
||||
'invalid-ref.json': {
|
||||
nodes: [{ id: 1, type: 'missing' }],
|
||||
definitions: { subgraphs: [{ id: 'present' }] }
|
||||
}
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Failed to load subgraph blueprint',
|
||||
expect.any(Error)
|
||||
)
|
||||
expect(store.subgraphBlueprints).toHaveLength(0)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('rejects loaded blueprints without subgraph definitions', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
await mockFetch({
|
||||
'missing-definitions.json': {
|
||||
nodes: [{ id: 1, type: 'missing' }]
|
||||
}
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Failed to load subgraph blueprint',
|
||||
expect.any(Error)
|
||||
)
|
||||
expect(store.subgraphBlueprints).toHaveLength(0)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('rejects saving a blueprint whose active state has no subgraph definitions', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint?.changeTracker) throw new Error('Blueprint was not loaded')
|
||||
blueprint.changeTracker!.activeState = fromAny<ComfyWorkflowJSON, unknown>({
|
||||
nodes: [{ id: 1, type: '123' }]
|
||||
})
|
||||
|
||||
await expect(blueprint.save()).rejects.toThrow(
|
||||
'The root graph of a subgraph blueprint must consist of only a single subgraph node'
|
||||
)
|
||||
})
|
||||
|
||||
it('marks non-blueprint root nodes when saving an invalid blueprint', async () => {
|
||||
vi.mocked(comfyApp.canvas).draw = vi.fn()
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint?.changeTracker) throw new Error('Blueprint was not loaded')
|
||||
blueprint.changeTracker!.activeState = fromAny<ComfyWorkflowJSON, unknown>({
|
||||
nodes: [
|
||||
{ id: 1, type: '123' },
|
||||
{ id: 2, type: 'OtherNode' }
|
||||
],
|
||||
definitions: { subgraphs: [{ id: '123' }] }
|
||||
})
|
||||
|
||||
await expect(blueprint.save()).rejects.toThrow(
|
||||
'The root graph of a subgraph blueprint must consist of only a single subgraph node'
|
||||
)
|
||||
expect(comfyApp.canvas.draw).toHaveBeenCalledWith(true, true)
|
||||
})
|
||||
|
||||
it('does not save a loaded blueprint when first-save confirmation is cancelled', async () => {
|
||||
const confirm = vi.fn(() => false)
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm
|
||||
})
|
||||
)
|
||||
useSettingStore().settingValues['Comfy.Workflow.WarnBlueprintOverwrite'] =
|
||||
true
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint) throw new Error('Blueprint was not loaded')
|
||||
|
||||
const result = await blueprint.save()
|
||||
|
||||
expect(result).toBe(blueprint)
|
||||
expect(confirm).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'overwriteBlueprint',
|
||||
itemList: ['test']
|
||||
})
|
||||
)
|
||||
expect(api.storeUserData).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('saves a loaded blueprint after first-save confirmation', async () => {
|
||||
const confirm = vi.fn(() => true)
|
||||
vi.mocked(useDialogService).mockReturnValue(
|
||||
fromPartial<ReturnType<typeof useDialogService>>({
|
||||
prompt: vi.fn(() => 'testname'),
|
||||
confirm
|
||||
})
|
||||
)
|
||||
useSettingStore().settingValues['Comfy.Workflow.WarnBlueprintOverwrite'] =
|
||||
true
|
||||
vi.mocked(api.storeUserData).mockResolvedValue({
|
||||
status: 200,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
path: 'subgraphs/test.json',
|
||||
modified: Date.now(),
|
||||
size: 2
|
||||
})
|
||||
} as Response)
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint) throw new Error('Blueprint was not loaded')
|
||||
|
||||
await blueprint.save()
|
||||
|
||||
const [path, data, options] = vi.mocked(api.storeUserData).mock.calls[0]
|
||||
if (typeof data !== 'string') throw new Error('Expected saved JSON')
|
||||
expect(path).toBe('subgraphs/test.json')
|
||||
expect(JSON.parse(data)).toMatchObject({
|
||||
nodes: [{ type: '123', title: 'test' }],
|
||||
definitions: { subgraphs: [{ id: '123', name: 'test' }] }
|
||||
})
|
||||
expect(options).toEqual({
|
||||
overwrite: true,
|
||||
throwOnError: true,
|
||||
full_info: true
|
||||
})
|
||||
})
|
||||
|
||||
it('returns an already-loaded blueprint when loading without force', async () => {
|
||||
await mockFetch({ 'test.json': mockGraph })
|
||||
const blueprint = useWorkflowStore().getWorkflowByPath(
|
||||
'subgraphs/test.json'
|
||||
)
|
||||
if (!blueprint) throw new Error('Blueprint was not loaded')
|
||||
|
||||
await blueprint.load()
|
||||
|
||||
expect(api.getUserData).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle global blueprint with rejected data promise gracefully', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
await mockFetch(
|
||||
@@ -406,6 +779,29 @@ describe('useSubgraphStore', () => {
|
||||
expect(nodeDef?.description).toBe('This is a test blueprint')
|
||||
})
|
||||
|
||||
it('does not copy workflowRendererVersion into subgraph metadata on load', async () => {
|
||||
await mockFetch({
|
||||
'metadata-load.json': {
|
||||
nodes: [{ type: '123' }],
|
||||
definitions: {
|
||||
subgraphs: [{ id: '123', extra: {} }]
|
||||
},
|
||||
extra: {
|
||||
BlueprintDescription: 'Loaded description',
|
||||
workflowRendererVersion: 'Vue'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const blueprint = store.getBlueprint(
|
||||
BLUEPRINT_TYPE_PREFIX + 'metadata-load'
|
||||
)
|
||||
|
||||
expect(blueprint.definitions!.subgraphs![0].extra).toEqual({
|
||||
BlueprintDescription: 'Loaded description'
|
||||
})
|
||||
})
|
||||
|
||||
it('should not duplicate metadata in both workflow extra and subgraph extra when publishing', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
@@ -465,6 +861,59 @@ describe('useSubgraphStore', () => {
|
||||
expect(subgraphExtra?.BlueprintDescription).toBeUndefined()
|
||||
expect(subgraphExtra?.BlueprintSearchAliases).toBeUndefined()
|
||||
})
|
||||
|
||||
it('keeps workflowRendererVersion in subgraph extra when publishing', async () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const subgraphNode = createTestSubgraphNode(subgraph)
|
||||
subgraphNode.graph!.add(subgraphNode)
|
||||
|
||||
subgraph.extra = {
|
||||
BlueprintDescription: 'Test description',
|
||||
workflowRendererVersion: 'Vue'
|
||||
}
|
||||
|
||||
vi.mocked(comfyApp.canvas).selectedItems = new Set([subgraphNode])
|
||||
vi.mocked(comfyApp.canvas)._serializeItems = vi.fn(() => {
|
||||
const serializedSubgraph = fromPartial<ExportedSubgraph>({
|
||||
...subgraph.serialize(),
|
||||
links: [],
|
||||
groups: [],
|
||||
version: 1
|
||||
})
|
||||
return {
|
||||
nodes: [subgraphNode.serialize()],
|
||||
subgraphs: [serializedSubgraph]
|
||||
}
|
||||
})
|
||||
|
||||
let savedWorkflowData: Record<string, unknown> | null = null
|
||||
vi.mocked(api.storeUserData).mockImplementation(async (_path, data) => {
|
||||
savedWorkflowData = JSON.parse(data as string)
|
||||
return {
|
||||
status: 200,
|
||||
json: () =>
|
||||
Promise.resolve({
|
||||
path: 'subgraphs/testname.json',
|
||||
modified: Date.now(),
|
||||
size: 2
|
||||
})
|
||||
} as Response
|
||||
})
|
||||
|
||||
await mockFetch({ 'testname.json': mockGraph })
|
||||
await store.publishSubgraph()
|
||||
|
||||
expect(savedWorkflowData).not.toBeNull()
|
||||
expect(savedWorkflowData!.extra).toEqual({
|
||||
BlueprintDescription: 'Test description'
|
||||
})
|
||||
const definitions = savedWorkflowData!.definitions as {
|
||||
subgraphs: Array<{ extra?: Record<string, unknown> }>
|
||||
}
|
||||
expect(definitions.subgraphs[0]?.extra?.workflowRendererVersion).toBe(
|
||||
'Vue'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('subgraph definition category', () => {
|
||||
|
||||
@@ -8,10 +8,12 @@ import {
|
||||
hexToRgb,
|
||||
hsbToRgb,
|
||||
hsvaToHex,
|
||||
isColorFormat,
|
||||
isTransparent,
|
||||
luminance,
|
||||
parseToRgb,
|
||||
readableTextColor,
|
||||
rgbToHsl,
|
||||
rgbToHex,
|
||||
textOnColor,
|
||||
toHexFromFormat
|
||||
@@ -236,6 +238,24 @@ describe('colorUtil conversions', () => {
|
||||
it('parses 8-digit hex and ignores the alpha channel in RGB output', () => {
|
||||
expect(parseToRgb('#ff000080')).toEqual({ r: 255, g: 0, b: 0 })
|
||||
})
|
||||
|
||||
it('covers HSL hue sectors outside primary colors', () => {
|
||||
expect(parseToRgb('hsl(60, 100%, 50%)')).toEqual({
|
||||
r: 255,
|
||||
g: 255,
|
||||
b: 0
|
||||
})
|
||||
expect(parseToRgb('hsl(180, 100%, 50%)')).toEqual({
|
||||
r: 0,
|
||||
g: 255,
|
||||
b: 255
|
||||
})
|
||||
expect(parseToRgb('hsl(300, 100%, 50%)')).toEqual({
|
||||
r: 255,
|
||||
g: 0,
|
||||
b: 255
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('hsbToRgb normalization', () => {
|
||||
@@ -247,6 +267,24 @@ describe('colorUtil conversions', () => {
|
||||
b: 255
|
||||
})
|
||||
})
|
||||
|
||||
it('covers all hue ranges', () => {
|
||||
expect(hsbToRgb({ h: 90, s: 100, b: 100 })).toEqual({
|
||||
r: 127,
|
||||
g: 255,
|
||||
b: 0
|
||||
})
|
||||
expect(hsbToRgb({ h: 180, s: 100, b: 100 })).toEqual({
|
||||
r: 0,
|
||||
g: 255,
|
||||
b: 255
|
||||
})
|
||||
expect(hsbToRgb({ h: 300, s: 100, b: 100 })).toEqual({
|
||||
r: 255,
|
||||
g: 0,
|
||||
b: 255
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isTransparent', () => {
|
||||
@@ -263,12 +301,27 @@ describe('colorUtil conversions', () => {
|
||||
})
|
||||
|
||||
it('returns false for fully opaque hex colors', () => {
|
||||
expect(isTransparent('red')).toBe(false)
|
||||
expect(isTransparent('#ff0000')).toBe(false)
|
||||
expect(isTransparent('#ff0000ff')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('toHexFromFormat', () => {
|
||||
it('normalizes supported hex spellings', () => {
|
||||
expect(toHexFromFormat('', 'hex')).toBe('#000000')
|
||||
expect(toHexFromFormat('abc', 'hex')).toBe('#abc')
|
||||
expect(toHexFromFormat('#abc', 'hex')).toBe('#abc')
|
||||
expect(toHexFromFormat('#abcdef', 'hex')).toBe('#abcdef')
|
||||
expect(toHexFromFormat('abcdef12', 'hex')).toBe('#abcdef12')
|
||||
expect(toHexFromFormat('#abcdef12', 'hex')).toBe('#abcdef12')
|
||||
})
|
||||
|
||||
it('converts rgb strings and rejects non-string rgb input', () => {
|
||||
expect(toHexFromFormat('rgb(255, 128, 0)', 'rgb')).toBe('#ff8000')
|
||||
expect(toHexFromFormat({ r: 255, g: 128, b: 0 }, 'rgb')).toBe('#000000')
|
||||
})
|
||||
|
||||
it('treats an HSV object (with v field) the same as an HSB object', () => {
|
||||
const hsbObject = { h: 120, s: 100, b: 100 }
|
||||
const hsvObject = { h: 120, s: 100, v: 100 }
|
||||
@@ -281,6 +334,7 @@ describe('colorUtil conversions', () => {
|
||||
|
||||
it('returns #000000 for unparseable hsb input', () => {
|
||||
expect(toHexFromFormat({ h: 0 }, 'hsb')).toBe('#000000')
|
||||
expect(toHexFromFormat('hsb()', 'hsb')).toBe('#000000')
|
||||
})
|
||||
|
||||
it('prefixes a bare 6-digit hex with #', () => {
|
||||
@@ -295,6 +349,22 @@ describe('colorUtil conversions', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('rgbToHsl', () => {
|
||||
it('handles light colors and wrapped red hue', () => {
|
||||
expect(rgbToHsl({ r: 255, g: 128, b: 128 }).s).toBeCloseTo(1)
|
||||
expect(rgbToHsl({ r: 255, g: 0, b: 128 }).h).toBeCloseTo(0.916, 2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isColorFormat', () => {
|
||||
it('recognizes public color format ids', () => {
|
||||
expect(isColorFormat('hex')).toBe(true)
|
||||
expect(isColorFormat('rgb')).toBe(true)
|
||||
expect(isColorFormat('hsb')).toBe(true)
|
||||
expect(isColorFormat('hsl')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('readableTextColor / textOnColor', () => {
|
||||
it('lightens dark colors', () => {
|
||||
expect(readableTextColor('#000000')).not.toBe('rgb(0,0,0)')
|
||||
@@ -359,6 +429,10 @@ describe('colorUtil - adjustColor', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('returns the original value when no adjustments are requested', () => {
|
||||
expect(adjustColor('#123456', {})).toBe('#123456')
|
||||
})
|
||||
|
||||
it('treats 5-char hex as valid color with alpha', () => {
|
||||
const result = adjustColor('#f008', {
|
||||
lightness: targetLightness,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user