mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-07-03 05:38:26 +00:00
Compare commits
4 Commits
feat/creat
...
codex/cove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
98b31d1c71 | ||
|
|
31d1d8d84d | ||
|
|
2ec2a0e091 | ||
|
|
9cf5c9a93f |
@@ -15,7 +15,7 @@ const { categories } = defineProps<{
|
||||
|
||||
const activeSection = ref(categories[0]?.value ?? '')
|
||||
|
||||
const HEADER_OFFSET = -144
|
||||
const HEADER_OFFSET_PX = -144
|
||||
const BOTTOM_THRESHOLD_PX = 4
|
||||
const SCROLL_SAFETY_MS = 1500
|
||||
|
||||
@@ -52,7 +52,7 @@ function scrollToSection(id: string) {
|
||||
const el = document.getElementById(id)
|
||||
if (el) {
|
||||
scrollTo(el, {
|
||||
offset: HEADER_OFFSET,
|
||||
offset: HEADER_OFFSET_PX,
|
||||
duration: 0.8,
|
||||
immediate: prefersReducedMotion(),
|
||||
onComplete: clearScrollLock
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<li
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow before:content-['']"
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow"
|
||||
>
|
||||
<slot />
|
||||
</li>
|
||||
|
||||
@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
|
||||
const subscriptionDialog = useSubscriptionDialog()
|
||||
|
||||
function handleClick() {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ComputedRef, Ref } from 'vue'
|
||||
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type {
|
||||
BillingStatus,
|
||||
@@ -75,9 +76,10 @@ export interface BillingActions {
|
||||
*/
|
||||
requireActiveSubscription: () => Promise<void>
|
||||
/**
|
||||
* Shows the subscription dialog.
|
||||
* Shows the subscription dialog. Pass a reason so the paywall open and any
|
||||
* downstream checkout stay attributed to the triggering product moment.
|
||||
*/
|
||||
showSubscriptionDialog: () => void
|
||||
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
|
||||
}
|
||||
|
||||
export interface BillingState {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
getTierFeatures
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
PreviewSubscribeOptions,
|
||||
SubscribeOptions
|
||||
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
|
||||
return activeContext.value.requireActiveSubscription()
|
||||
}
|
||||
|
||||
function showSubscriptionDialog() {
|
||||
return activeContext.value.showSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
|
||||
return activeContext.value.showSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
|
||||
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingStatus,
|
||||
BillingSubscriptionStatus,
|
||||
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
legacyShowSubscriptionDialog()
|
||||
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
legacyShowSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
legacyShowSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -25,6 +25,6 @@ function handleClose() {
|
||||
}
|
||||
|
||||
function handleSubscribe() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
|
||||
}
|
||||
</script>
|
||||
|
||||
85
src/platform/assets/composables/media/useAssetsApi.test.ts
Normal file
85
src/platform/assets/composables/media/useAssetsApi.test.ts
Normal file
@@ -0,0 +1,85 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { useAssetsApi } from './useAssetsApi'
|
||||
|
||||
const mockAssetsStore = vi.hoisted(() => ({
|
||||
inputAssets: [] as AssetItem[],
|
||||
historyAssets: [] as AssetItem[],
|
||||
inputLoading: false,
|
||||
historyLoading: false,
|
||||
inputError: null as string | null,
|
||||
historyError: null as string | null,
|
||||
hasMoreHistory: false,
|
||||
isLoadingMore: false,
|
||||
updateInputs: vi.fn(),
|
||||
updateHistory: vi.fn(),
|
||||
loadMoreHistory: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/assetsStore', () => ({
|
||||
useAssetsStore: () => mockAssetsStore
|
||||
}))
|
||||
|
||||
function createAsset(id: string): AssetItem {
|
||||
return {
|
||||
id,
|
||||
name: `${id}.png`,
|
||||
size: 1,
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
tags: ['input']
|
||||
}
|
||||
}
|
||||
|
||||
describe('useAssetsApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockAssetsStore.inputAssets = [createAsset('input-1')]
|
||||
mockAssetsStore.historyAssets = [createAsset('history-1')]
|
||||
mockAssetsStore.inputLoading = true
|
||||
mockAssetsStore.historyLoading = false
|
||||
mockAssetsStore.inputError = 'input-error'
|
||||
mockAssetsStore.historyError = 'history-error'
|
||||
mockAssetsStore.hasMoreHistory = true
|
||||
mockAssetsStore.isLoadingMore = true
|
||||
})
|
||||
|
||||
it('uses input assets and refreshes inputs', async () => {
|
||||
const api = useAssetsApi('input')
|
||||
|
||||
expect(api.media.value).toEqual([createAsset('input-1')])
|
||||
expect(api.loading.value).toBe(true)
|
||||
expect(api.error.value).toBe('input-error')
|
||||
expect(api.hasMore.value).toBe(false)
|
||||
expect(api.isLoadingMore.value).toBe(false)
|
||||
|
||||
await expect(api.fetchMediaList()).resolves.toEqual([
|
||||
createAsset('input-1')
|
||||
])
|
||||
await expect(api.refresh()).resolves.toEqual([createAsset('input-1')])
|
||||
await api.loadMore()
|
||||
|
||||
expect(mockAssetsStore.updateInputs).toHaveBeenCalledTimes(2)
|
||||
expect(mockAssetsStore.updateHistory).not.toHaveBeenCalled()
|
||||
expect(mockAssetsStore.loadMoreHistory).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses output history and loads more history', async () => {
|
||||
const api = useAssetsApi('output')
|
||||
|
||||
expect(api.media.value).toEqual([createAsset('history-1')])
|
||||
expect(api.loading.value).toBe(false)
|
||||
expect(api.error.value).toBe('history-error')
|
||||
expect(api.hasMore.value).toBe(true)
|
||||
expect(api.isLoadingMore.value).toBe(true)
|
||||
|
||||
await expect(api.fetchMediaList()).resolves.toEqual([
|
||||
createAsset('history-1')
|
||||
])
|
||||
await api.loadMore()
|
||||
|
||||
expect(mockAssetsStore.updateHistory).toHaveBeenCalledOnce()
|
||||
expect(mockAssetsStore.updateInputs).not.toHaveBeenCalled()
|
||||
expect(mockAssetsStore.loadMoreHistory).toHaveBeenCalledOnce()
|
||||
})
|
||||
})
|
||||
@@ -8,6 +8,7 @@ import { useI18n } from 'vue-i18n'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { MediaAssetKey } from '@/platform/assets/schemas/mediaAssetSchema'
|
||||
import { api } from '@/scripts/api'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetMeta } from '@/platform/assets/schemas/mediaAssetSchema'
|
||||
import type * as outputAssetUtilModule from '../utils/outputAssetUtil'
|
||||
@@ -18,6 +19,12 @@ const mockIsCloud = vi.hoisted(() => ({ value: false }))
|
||||
|
||||
// Track the filename passed to createAnnotatedPath
|
||||
const capturedFilenames = vi.hoisted(() => ({ values: [] as string[] }))
|
||||
const capturedAnnotatedPaths = vi.hoisted(() => ({
|
||||
values: [] as Array<{
|
||||
item: { filename: string; subfolder?: string; type?: string }
|
||||
options: { rootFolder?: string }
|
||||
}>
|
||||
}))
|
||||
|
||||
const mockDownloadFile = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/base/common/downloadUtil', () => ({
|
||||
@@ -73,9 +80,10 @@ vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({})
|
||||
}))
|
||||
|
||||
const mockCopyToClipboard = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/composables/useCopyToClipboard', () => ({
|
||||
useCopyToClipboard: () => ({
|
||||
copyToClipboard: vi.fn()
|
||||
copyToClipboard: mockCopyToClipboard
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -93,45 +101,50 @@ vi.mock('@/platform/workflow/utils/workflowExtractionUtil', () => ({
|
||||
extractWorkflowFromAsset: mockExtractWorkflowFromAsset
|
||||
}))
|
||||
|
||||
const mockAddNodeOnGraph = vi.hoisted(() => vi.fn())
|
||||
const mockGetCanvasCenter = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({
|
||||
addNodeOnGraph: vi.fn().mockReturnValue(
|
||||
fromAny<LGraphNode, unknown>({
|
||||
widgets: [{ name: 'image', value: '', callback: vi.fn() }],
|
||||
graph: { setDirtyCanvas: vi.fn() }
|
||||
})
|
||||
),
|
||||
getCanvasCenter: vi.fn().mockReturnValue([100, 100])
|
||||
addNodeOnGraph: mockAddNodeOnGraph,
|
||||
getCanvasCenter: mockGetCanvasCenter
|
||||
})
|
||||
}))
|
||||
|
||||
const mockNodeDefsByName = vi.hoisted(() => ({
|
||||
value: {
|
||||
LoadImage: {
|
||||
name: 'LoadImage',
|
||||
display_name: 'Load Image'
|
||||
}
|
||||
} as Record<string, unknown>
|
||||
}))
|
||||
vi.mock('@/stores/nodeDefStore', () => ({
|
||||
useNodeDefStore: () => ({
|
||||
nodeDefsByName: {
|
||||
LoadImage: {
|
||||
name: 'LoadImage',
|
||||
display_name: 'Load Image'
|
||||
}
|
||||
}
|
||||
nodeDefsByName: mockNodeDefsByName.value
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/createAnnotatedPath', () => ({
|
||||
createAnnotatedPath: vi.fn((item: { filename: string }) => {
|
||||
capturedFilenames.values.push(item.filename)
|
||||
return item.filename
|
||||
})
|
||||
createAnnotatedPath: vi.fn(
|
||||
(
|
||||
item: { filename: string; subfolder?: string; type?: string },
|
||||
options: { rootFolder?: string }
|
||||
) => {
|
||||
capturedAnnotatedPaths.values.push({ item, options })
|
||||
capturedFilenames.values.push(item.filename)
|
||||
return item.filename
|
||||
}
|
||||
)
|
||||
}))
|
||||
|
||||
const mockDetectNodeTypeFromFilename = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/utils/loaderNodeUtil', () => ({
|
||||
detectNodeTypeFromFilename: vi.fn().mockReturnValue({
|
||||
nodeType: 'LoadImage',
|
||||
widgetName: 'image'
|
||||
})
|
||||
detectNodeTypeFromFilename: mockDetectNodeTypeFromFilename
|
||||
}))
|
||||
|
||||
const mockIsResultItemType = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/utils/typeGuardUtil', () => ({
|
||||
isResultItemType: vi.fn().mockReturnValue(true)
|
||||
isResultItemType: mockIsResultItemType
|
||||
}))
|
||||
|
||||
const mockGetAssetType = vi.hoisted(() => vi.fn())
|
||||
@@ -186,7 +199,9 @@ vi.mock('@/scripts/api', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
const mockAppGraph = vi.hoisted(() => ({ value: { _nodes: [] as unknown[] } }))
|
||||
const mockAppGraph = vi.hoisted(() => ({
|
||||
value: { _nodes: [] as unknown[] } as { _nodes: unknown[] } | null
|
||||
}))
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
get graph() {
|
||||
@@ -291,7 +306,43 @@ describe('useMediaAssetActions', () => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
capturedFilenames.values = []
|
||||
capturedAnnotatedPaths.values = []
|
||||
mockIsCloud.value = false
|
||||
mockAppGraph.value = { _nodes: [] }
|
||||
mockDownloadFile.mockReset()
|
||||
mockCopyToClipboard.mockReset()
|
||||
mockShowDialog.mockReset()
|
||||
mockAddNodeOnGraph.mockReset()
|
||||
mockAddNodeOnGraph.mockReturnValue(
|
||||
fromAny<LGraphNode, unknown>({
|
||||
widgets: [{ name: 'image', value: '', callback: vi.fn() }],
|
||||
graph: { setDirtyCanvas: vi.fn() }
|
||||
})
|
||||
)
|
||||
mockGetCanvasCenter.mockReset()
|
||||
mockGetCanvasCenter.mockReturnValue([100, 100])
|
||||
mockNodeDefsByName.value = {
|
||||
LoadImage: {
|
||||
name: 'LoadImage',
|
||||
display_name: 'Load Image'
|
||||
}
|
||||
}
|
||||
mockDetectNodeTypeFromFilename.mockReset()
|
||||
mockDetectNodeTypeFromFilename.mockReturnValue({
|
||||
nodeType: 'LoadImage',
|
||||
widgetName: 'image'
|
||||
})
|
||||
mockIsResultItemType.mockReset()
|
||||
mockIsResultItemType.mockReturnValue(true)
|
||||
mockExtractWorkflowFromAsset.mockReset()
|
||||
mockOpenWorkflowAction.mockReset()
|
||||
mockExportWorkflowAction.mockReset()
|
||||
mockCreateAssetExport.mockReset()
|
||||
mockCreateAssetExport.mockResolvedValue({
|
||||
task_id: 'test-task-id',
|
||||
status: 'pending'
|
||||
})
|
||||
mockDeleteAsset.mockReset()
|
||||
mockGetOutputAssetMetadata.mockReset()
|
||||
mockGetOutputAssetMetadata.mockReturnValue(null)
|
||||
mockGetAssetType.mockReset()
|
||||
@@ -299,7 +350,139 @@ describe('useMediaAssetActions', () => {
|
||||
mockResolveOutputAssetItems.mockResolvedValue([])
|
||||
})
|
||||
|
||||
describe('copyJobId', () => {
|
||||
it('does nothing when no asset is available', async () => {
|
||||
const { actions, unmount } = mountMediaActions()
|
||||
|
||||
await actions.copyJobId()
|
||||
|
||||
expect(mockCopyToClipboard).not.toHaveBeenCalled()
|
||||
expect(useToast().add).not.toHaveBeenCalled()
|
||||
|
||||
unmount()
|
||||
})
|
||||
|
||||
it('warns when the asset has no job id', async () => {
|
||||
mockGetAssetType.mockReturnValue('input')
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.copyJobId(createMockAsset())
|
||||
|
||||
expect(mockCopyToClipboard).not.toHaveBeenCalled()
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
|
||||
it('copies the metadata job id when present', async () => {
|
||||
mockGetOutputAssetMetadata.mockReturnValue({ jobId: 'job-from-meta' })
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.copyJobId(createMockAsset())
|
||||
|
||||
expect(mockCopyToClipboard).toHaveBeenCalledWith('job-from-meta')
|
||||
})
|
||||
|
||||
it('copies the output asset id when metadata omits the job id', async () => {
|
||||
mockGetAssetType.mockReturnValue('output')
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.copyJobId(createMockAsset({ id: 'history-id' }))
|
||||
|
||||
expect(mockCopyToClipboard).toHaveBeenCalledWith('history-id')
|
||||
})
|
||||
})
|
||||
|
||||
describe('addWorkflow', () => {
|
||||
it('does nothing when no asset is available', async () => {
|
||||
const { actions, unmount } = mountMediaActions()
|
||||
|
||||
await actions.addWorkflow()
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
expect(useToast().add).not.toHaveBeenCalled()
|
||||
|
||||
unmount()
|
||||
})
|
||||
|
||||
it('uses the injected media asset when no explicit asset is provided', async () => {
|
||||
const mediaAsset = createMockMediaAsset({ name: 'context-image.png' })
|
||||
const { actions, unmount } = mountMediaActions(mediaAsset)
|
||||
|
||||
await actions.addWorkflow()
|
||||
|
||||
expect(capturedFilenames.values).toContain('context-image.png')
|
||||
|
||||
unmount()
|
||||
})
|
||||
|
||||
it('warns when the filename has no compatible loader node', async () => {
|
||||
mockDetectNodeTypeFromFilename.mockReturnValue({
|
||||
nodeType: undefined,
|
||||
widgetName: undefined
|
||||
})
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addWorkflow(createMockAsset({ name: 'notes.txt' }))
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
|
||||
it('reports missing node definitions', async () => {
|
||||
mockNodeDefsByName.value = {}
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addWorkflow(createMockAsset())
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
|
||||
it('reports loader-node creation failure', async () => {
|
||||
mockAddNodeOnGraph.mockReturnValue(undefined)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addWorkflow(createMockAsset())
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
|
||||
it('still adds the node when the expected widget is absent', async () => {
|
||||
const setDirtyCanvas = vi.fn()
|
||||
mockAddNodeOnGraph.mockReturnValue(
|
||||
fromAny<LGraphNode, unknown>({
|
||||
widgets: [{ name: 'other', value: '' }],
|
||||
graph: { setDirtyCanvas }
|
||||
})
|
||||
)
|
||||
mockGetOutputAssetMetadata.mockReturnValue({ subfolder: 'nested' })
|
||||
mockGetAssetType.mockReturnValue('custom')
|
||||
mockIsResultItemType.mockReturnValue(false)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addWorkflow(createMockAsset({ name: 'asset.png' }))
|
||||
|
||||
expect(capturedAnnotatedPaths.values.at(-1)).toEqual({
|
||||
item: {
|
||||
filename: 'asset.png',
|
||||
subfolder: 'nested',
|
||||
type: undefined
|
||||
},
|
||||
options: { rootFolder: 'input' }
|
||||
})
|
||||
expect(setDirtyCanvas).toHaveBeenCalledWith(true, true)
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
|
||||
describe('OSS mode (isCloud = false)', () => {
|
||||
beforeEach(() => {
|
||||
mockIsCloud.value = false
|
||||
@@ -366,6 +549,83 @@ describe('useMediaAssetActions', () => {
|
||||
})
|
||||
|
||||
describe('addMultipleToWorkflow', () => {
|
||||
it('does nothing for an empty selection', async () => {
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addMultipleToWorkflow([])
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
expect(useToast().add).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows a failure toast when none of the selected assets can be added', async () => {
|
||||
mockDetectNodeTypeFromFilename
|
||||
.mockReturnValueOnce({ nodeType: undefined, widgetName: undefined })
|
||||
.mockReturnValueOnce({ nodeType: 'MissingNode', widgetName: 'image' })
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addMultipleToWorkflow([
|
||||
createMockAsset({ id: 'a', name: 'unsupported.txt' }),
|
||||
createMockAsset({ id: 'b', name: 'missing.png' })
|
||||
])
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
|
||||
it('shows a partial warning when only some nodes are added', async () => {
|
||||
mockAddNodeOnGraph
|
||||
.mockReturnValueOnce(
|
||||
fromAny<LGraphNode, unknown>({
|
||||
widgets: [{ name: 'image', value: '', callback: vi.fn() }],
|
||||
graph: { setDirtyCanvas: vi.fn() }
|
||||
})
|
||||
)
|
||||
.mockReturnValueOnce(undefined)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addMultipleToWorkflow([
|
||||
createMockAsset({ id: 'a', name: 'a.png' }),
|
||||
createMockAsset({ id: 'b', name: 'b.png' })
|
||||
])
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
|
||||
it('adds assets without a matching widget using untyped paths', async () => {
|
||||
const setDirtyCanvas = vi.fn()
|
||||
mockAddNodeOnGraph.mockReturnValue(
|
||||
fromAny<LGraphNode, unknown>({
|
||||
widgets: [{ name: 'other', value: '' }],
|
||||
graph: { setDirtyCanvas }
|
||||
})
|
||||
)
|
||||
mockGetAssetType.mockReturnValue('custom')
|
||||
mockIsResultItemType.mockReturnValue(false)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.addMultipleToWorkflow([
|
||||
createMockAsset({ id: 'asset-1', name: 'asset-1.png' })
|
||||
])
|
||||
|
||||
expect(capturedAnnotatedPaths.values.at(-1)).toEqual({
|
||||
item: {
|
||||
filename: 'asset-1.png',
|
||||
subfolder: '',
|
||||
type: undefined
|
||||
},
|
||||
options: { rootFolder: undefined }
|
||||
})
|
||||
expect(setDirtyCanvas).toHaveBeenCalledWith(true, true)
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
|
||||
describe('Cloud mode (isCloud = true)', () => {
|
||||
beforeEach(() => {
|
||||
mockIsCloud.value = true
|
||||
@@ -397,10 +657,56 @@ describe('useMediaAssetActions', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('openWorkflow', () => {
|
||||
beforeEach(() => {
|
||||
mockExtractWorkflowFromAsset.mockResolvedValue({
|
||||
workflow: { version: 0.4 },
|
||||
filename: 'workflow.json'
|
||||
})
|
||||
})
|
||||
|
||||
it('does nothing when no asset is available', async () => {
|
||||
const { actions, unmount } = mountMediaActions()
|
||||
|
||||
await actions.openWorkflow()
|
||||
|
||||
expect(mockExtractWorkflowFromAsset).not.toHaveBeenCalled()
|
||||
expect(mockOpenWorkflowAction).not.toHaveBeenCalled()
|
||||
|
||||
unmount()
|
||||
})
|
||||
|
||||
it('shows a success toast after opening the workflow', async () => {
|
||||
mockOpenWorkflowAction.mockResolvedValue({ success: true })
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.openWorkflow(createMockAsset())
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
|
||||
it('uses the fallback warning when opening returns no error message', async () => {
|
||||
mockOpenWorkflowAction.mockResolvedValue({ success: false })
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.openWorkflow(createMockAsset())
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'warn',
|
||||
detail: 'mediaAsset.noWorkflowDataFound'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('exportWorkflow', () => {
|
||||
const successResult = { success: true } as const
|
||||
const cancelledResult = { success: false, cancelled: true } as const
|
||||
const failureResult = { success: false, error: 'boom' } as const
|
||||
const failureWithoutError = { success: false } as const
|
||||
const noWorkflowResult = {
|
||||
success: false,
|
||||
error: 'No workflow data available'
|
||||
@@ -455,6 +761,31 @@ describe('useMediaAssetActions', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when no asset is available', async () => {
|
||||
const { actions, unmount } = mountMediaActions()
|
||||
|
||||
await actions.exportWorkflow()
|
||||
|
||||
expect(mockExtractWorkflowFromAsset).not.toHaveBeenCalled()
|
||||
expect(mockExportWorkflowAction).not.toHaveBeenCalled()
|
||||
|
||||
unmount()
|
||||
})
|
||||
|
||||
it('uses the fallback error when export fails without a message', async () => {
|
||||
mockExportWorkflowAction.mockResolvedValue(failureWithoutError)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.exportWorkflow(createMockAsset())
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'error',
|
||||
detail: 'mediaAsset.failedToExportWorkflow'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('shows no toast when every asset in a bulk export is cancelled', async () => {
|
||||
mockExportWorkflowAction.mockResolvedValue(cancelledResult)
|
||||
const actions = useMediaAssetActions()
|
||||
@@ -500,6 +831,118 @@ describe('useMediaAssetActions', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('openMultipleWorkflows', () => {
|
||||
beforeEach(() => {
|
||||
mockExtractWorkflowFromAsset.mockResolvedValue({
|
||||
workflow: { version: 0.4 },
|
||||
filename: 'workflow.json'
|
||||
})
|
||||
})
|
||||
|
||||
it('does nothing for an empty selection', async () => {
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.openMultipleWorkflows([])
|
||||
|
||||
expect(mockOpenWorkflowAction).not.toHaveBeenCalled()
|
||||
expect(useToast().add).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows success when every workflow opens', async () => {
|
||||
mockOpenWorkflowAction.mockResolvedValue({ success: true })
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.openMultipleWorkflows([
|
||||
createMockAsset({ id: 'a' }),
|
||||
createMockAsset({ id: 'b' })
|
||||
])
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
|
||||
it('shows a missing-workflow warning when none open', async () => {
|
||||
mockOpenWorkflowAction.mockResolvedValue({ success: false })
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.openMultipleWorkflows([
|
||||
createMockAsset({ id: 'a' }),
|
||||
createMockAsset({ id: 'b' })
|
||||
])
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
|
||||
it('shows a partial warning when extraction throws for one asset', async () => {
|
||||
mockExtractWorkflowFromAsset
|
||||
.mockResolvedValueOnce({
|
||||
workflow: { version: 0.4 },
|
||||
filename: 'ok.json'
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('missing workflow'))
|
||||
mockOpenWorkflowAction.mockResolvedValue({ success: true })
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.openMultipleWorkflows([
|
||||
createMockAsset({ id: 'a' }),
|
||||
createMockAsset({ id: 'b' })
|
||||
])
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('exportMultipleWorkflows', () => {
|
||||
beforeEach(() => {
|
||||
mockExtractWorkflowFromAsset.mockResolvedValue({
|
||||
workflow: { version: 0.4 },
|
||||
filename: 'workflow.json'
|
||||
})
|
||||
})
|
||||
|
||||
it('does nothing for an empty selection', async () => {
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.exportMultipleWorkflows([])
|
||||
|
||||
expect(mockExportWorkflowAction).not.toHaveBeenCalled()
|
||||
expect(useToast().add).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows no-workflows warning when every export fails', async () => {
|
||||
mockExportWorkflowAction.mockResolvedValue({
|
||||
success: false,
|
||||
error: 'boom'
|
||||
})
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.exportMultipleWorkflows([
|
||||
createMockAsset({ id: 'a' }),
|
||||
createMockAsset({ id: 'b' })
|
||||
])
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
|
||||
it('counts extraction failures as failed exports', async () => {
|
||||
mockExtractWorkflowFromAsset.mockRejectedValue(new Error('missing'))
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.exportMultipleWorkflows([createMockAsset()])
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('downloadAssets', () => {
|
||||
it('downloads the injected media asset when called without explicit assets', () => {
|
||||
const mediaAsset = createMockMediaAsset({
|
||||
@@ -534,6 +977,36 @@ describe('useMediaAssetActions', () => {
|
||||
unmount()
|
||||
})
|
||||
|
||||
it('uses the asset URL when no preview URL is available', () => {
|
||||
mockGetAssetType.mockReturnValue('input')
|
||||
const asset = createMockAsset({
|
||||
name: 'raw image.png',
|
||||
preview_url: undefined,
|
||||
user_metadata: { subfolder: 'uploads' }
|
||||
})
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
actions.downloadAssets([asset])
|
||||
|
||||
expect(mockDownloadFile).toHaveBeenCalledWith(
|
||||
'http://localhost:8188/api/view?filename=raw+image.png&type=input&subfolder=uploads',
|
||||
'raw image.png'
|
||||
)
|
||||
})
|
||||
|
||||
it('shows an error toast when a direct download throws', () => {
|
||||
mockDownloadFile.mockImplementation(() => {
|
||||
throw new Error('download failed')
|
||||
})
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
actions.downloadAssets([createMockAsset()])
|
||||
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
|
||||
it('keeps single explicit assets on the direct download path in cloud', () => {
|
||||
mockIsCloud.value = true
|
||||
mockGetOutputAssetMetadata.mockReturnValue({
|
||||
@@ -943,6 +1416,82 @@ describe('useMediaAssetActions', () => {
|
||||
})
|
||||
expect(payload.naming_strategy).toBe('preserve')
|
||||
})
|
||||
|
||||
it('should include asset ids for imported assets', async () => {
|
||||
mockGetAssetType.mockImplementation((asset: AssetItem) =>
|
||||
asset.tags?.includes('output') ? 'output' : 'input'
|
||||
)
|
||||
const asset1 = createMockAsset({ id: 'input-1', tags: ['input'] })
|
||||
const asset2 = createMockAsset({ id: 'input-2', tags: ['input'] })
|
||||
|
||||
const actions = useMediaAssetActions()
|
||||
actions.downloadAssets([asset1, asset2])
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockCreateAssetExport).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
const payload = mockCreateAssetExport.mock.calls[0][0]
|
||||
expect(payload.job_ids).toBeUndefined()
|
||||
expect(payload.asset_ids).toEqual(['input-1', 'input-2'])
|
||||
expect(payload.naming_strategy).toBe('preserve')
|
||||
})
|
||||
|
||||
it('should mix output job ids and imported asset ids', async () => {
|
||||
mockGetAssetType.mockImplementation((asset: AssetItem) =>
|
||||
asset.tags?.includes('output') ? 'output' : 'input'
|
||||
)
|
||||
const output = createMockAsset({
|
||||
id: 'history-id',
|
||||
name: 'output.png',
|
||||
tags: ['output']
|
||||
})
|
||||
const imported = createMockAsset({ id: 'input-id', tags: ['input'] })
|
||||
|
||||
const actions = useMediaAssetActions()
|
||||
actions.downloadAssets([output, imported])
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockCreateAssetExport).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
const payload = mockCreateAssetExport.mock.calls[0][0]
|
||||
expect(payload.job_ids).toEqual(['history-id'])
|
||||
expect(payload.asset_ids).toEqual(['input-id'])
|
||||
})
|
||||
|
||||
it('should only include a filtered output name once', async () => {
|
||||
const asset1 = createOutputAsset('a1', 'same.png', 'job1')
|
||||
const asset2 = createOutputAsset('a2', 'same.png', 'job1')
|
||||
|
||||
const actions = useMediaAssetActions()
|
||||
actions.downloadAssets([asset1, asset2])
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockCreateAssetExport).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
const payload = mockCreateAssetExport.mock.calls[0][0]
|
||||
expect(payload.job_asset_name_filters).toEqual({
|
||||
job1: ['same.png']
|
||||
})
|
||||
})
|
||||
|
||||
it('should show an error toast when ZIP export creation fails', async () => {
|
||||
mockCreateAssetExport.mockRejectedValueOnce(new Error('export failed'))
|
||||
const asset1 = createOutputAsset('a1', 'img1.png', 'job1')
|
||||
const asset2 = createOutputAsset('a2', 'img2.png', 'job2')
|
||||
|
||||
const actions = useMediaAssetActions()
|
||||
actions.downloadAssets([asset1, asset2])
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
expect(mockTrackExport).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('downloadAssets - export toast file count', () => {
|
||||
@@ -1033,6 +1582,200 @@ describe('useMediaAssetActions', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('deleteAssets', () => {
|
||||
it('returns false for an empty selection', async () => {
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
const result = await actions.deleteAssets([])
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(mockShowDialog).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns false when the user cancels', async () => {
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onCancel: () => void } }) => {
|
||||
props.onCancel()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
const result = await actions.deleteAssets(createMockAsset())
|
||||
|
||||
expect(result).toBe(false)
|
||||
expect(mockDeleteAsset).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('rejects imported asset deletion outside cloud mode', async () => {
|
||||
mockIsCloud.value = false
|
||||
mockGetAssetType.mockReturnValue('input')
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.deleteAssets(createMockAsset({ tags: ['input'] }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
expect(mockDeleteAsset).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('rejects output deletion when no job id can be resolved', async () => {
|
||||
mockIsCloud.value = true
|
||||
mockGetAssetType.mockReturnValue('output')
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.deleteAssets(
|
||||
createMockAsset({ id: '', name: 'orphan.png', tags: ['output'] })
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
expect(api.deleteItem).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('updates output history and input listings for mixed successful deletion', async () => {
|
||||
mockIsCloud.value = true
|
||||
mockGetAssetType.mockImplementation((asset: AssetItem) =>
|
||||
asset.tags?.includes('output') ? 'output' : 'input'
|
||||
)
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.deleteAssets([
|
||||
createMockAsset({ id: 'history-1', tags: ['output'] }),
|
||||
createMockAsset({ id: 'input-1', tags: ['input'] })
|
||||
])
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockUpdateHistory).toHaveBeenCalled()
|
||||
})
|
||||
expect(mockUpdateInputs).toHaveBeenCalled()
|
||||
expect(api.deleteItem).toHaveBeenCalledWith('history', 'history-1')
|
||||
expect(mockDeleteAsset).toHaveBeenCalledWith('input-1')
|
||||
})
|
||||
|
||||
it('skips graph cleanup when there is no root graph', async () => {
|
||||
mockIsCloud.value = true
|
||||
mockGetAssetType.mockReturnValue('input')
|
||||
mockAppGraph.value = null
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.deleteAssets(createMockAsset({ tags: ['input'] }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockDeleteAsset).toHaveBeenCalled()
|
||||
})
|
||||
expect(mockClearNodePreviewCache).not.toHaveBeenCalled()
|
||||
expect(mockClearWidgetValues).not.toHaveBeenCalled()
|
||||
expect(mockCaptureCanvasState).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses temp widget-value variants when deleting temp assets', async () => {
|
||||
mockIsCloud.value = true
|
||||
mockGetAssetType.mockReturnValue('temp')
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.deleteAssets(
|
||||
createMockAsset({
|
||||
id: 'temp-1',
|
||||
name: 'preview.png',
|
||||
hash: 'preview-hash.png',
|
||||
tags: ['temp']
|
||||
})
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockClearNodePreviewCache).toHaveBeenCalled()
|
||||
})
|
||||
const [, valuesArg] = mockClearNodePreviewCache.mock.calls[0]
|
||||
expect(valuesArg).toEqual(
|
||||
new Set(['preview.png [temp]', 'preview-hash.png'])
|
||||
)
|
||||
})
|
||||
|
||||
it('uses hash-only cleanup values when the asset name is empty', async () => {
|
||||
mockIsCloud.value = true
|
||||
mockGetAssetType.mockReturnValue('input')
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.deleteAssets(
|
||||
createMockAsset({
|
||||
id: 'hash-only',
|
||||
name: '',
|
||||
hash: 'only-hash.png',
|
||||
tags: ['input']
|
||||
})
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockClearNodePreviewCache).toHaveBeenCalled()
|
||||
})
|
||||
const [, valuesArg] = mockClearNodePreviewCache.mock.calls[0]
|
||||
expect(valuesArg).toEqual(new Set(['only-hash.png']))
|
||||
})
|
||||
|
||||
it('shows a partial warning and cleans up only successfully deleted assets', async () => {
|
||||
mockIsCloud.value = true
|
||||
mockGetAssetType.mockReturnValue('input')
|
||||
mockDeleteAsset
|
||||
.mockResolvedValueOnce(undefined)
|
||||
.mockRejectedValueOnce(new Error('delete failed'))
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
await actions.deleteAssets([
|
||||
createMockAsset({ id: 'ok', name: 'ok.png', tags: ['input'] }),
|
||||
createMockAsset({ id: 'bad', name: 'bad.png', tags: ['input'] })
|
||||
])
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(useToast().add).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'warn' })
|
||||
)
|
||||
})
|
||||
const [, valuesArg] = mockClearNodePreviewCache.mock.calls[0]
|
||||
expect(valuesArg).toEqual(new Set(['ok.png', 'ok.png [input]']))
|
||||
})
|
||||
})
|
||||
|
||||
describe('deleteAssets - model cache invalidation', () => {
|
||||
beforeEach(() => {
|
||||
mockIsCloud.value = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createApp, nextTick, ref } from 'vue'
|
||||
@@ -14,6 +15,7 @@ vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
getAssetMetadata: vi.fn(),
|
||||
uploadAssetAsync: vi.fn(),
|
||||
uploadAssetFromBase64: vi.fn(),
|
||||
uploadAssetPreviewImage: vi.fn()
|
||||
}
|
||||
}))
|
||||
@@ -248,6 +250,81 @@ describe('useUploadModelWizard', () => {
|
||||
expect(wizard.selectedModelType.value).toBe('checkpoints')
|
||||
})
|
||||
|
||||
it('does not fetch metadata until the URL matches a supported source', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
|
||||
expect(wizard.canFetchMetadata.value).toBe(false)
|
||||
await wizard.fetchMetadata()
|
||||
|
||||
expect(assetService.getAssetMetadata).not.toHaveBeenCalled()
|
||||
expect(wizard.currentStep.value).toBe(1)
|
||||
})
|
||||
|
||||
it('decodes metadata filenames and selects a matching model type tag', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.getAssetMetadata).mockResolvedValue({
|
||||
content_length: 100,
|
||||
final_url: 'https://huggingface.co/org/model',
|
||||
filename: '%E6%A8%A1%E5%9E%8B.safetensors',
|
||||
name: '%E5%90%8D%E7%A8%B1',
|
||||
tags: ['checkpoints'],
|
||||
preview_image: 'data:image/png;base64,abc'
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = ' https://huggingface.co/org/model '
|
||||
|
||||
await wizard.fetchMetadata()
|
||||
|
||||
expect(wizard.currentStep.value).toBe(2)
|
||||
expect(wizard.wizardData.value.url).toBe('https://huggingface.co/org/model')
|
||||
expect(wizard.wizardData.value.name).toBe('模型.safetensors')
|
||||
expect(wizard.wizardData.value.previewImage).toBe(
|
||||
'data:image/png;base64,abc'
|
||||
)
|
||||
expect(wizard.selectedModelType.value).toBe('checkpoints')
|
||||
})
|
||||
|
||||
it('keeps metadata text when percent decoding fails', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.getAssetMetadata).mockResolvedValue({
|
||||
content_length: 100,
|
||||
final_url: 'https://civitai.com/models/12345',
|
||||
filename: '%E0%A4%A',
|
||||
name: '%E0%A4%A',
|
||||
tags: []
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
|
||||
await wizard.fetchMetadata()
|
||||
|
||||
expect(wizard.currentStep.value).toBe(2)
|
||||
expect(wizard.wizardData.value.name).toBe('%E0%A4%A')
|
||||
expect(wizard.selectedModelType.value).toBeUndefined()
|
||||
})
|
||||
|
||||
it('uses the fallback metadata error for non-error rejections', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.getAssetMetadata).mockRejectedValue('no metadata')
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
|
||||
await wizard.fetchMetadata()
|
||||
|
||||
expect(wizard.currentStep.value).toBe(1)
|
||||
expect(wizard.uploadError.value).toBe(
|
||||
'Failed to retrieve metadata. Please check the link and try again.'
|
||||
)
|
||||
})
|
||||
|
||||
it('uploads with the required model type even if selection changes', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
@@ -279,6 +356,382 @@ describe('useUploadModelWizard', () => {
|
||||
expect(result?.modelType).toBe('checkpoints')
|
||||
})
|
||||
|
||||
it('clears upload errors and type mismatches when the URL changes', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-lora',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'loras']
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(
|
||||
ref([
|
||||
{ name: 'Checkpoint', value: 'checkpoints' },
|
||||
{ name: 'LoRA', value: 'loras' }
|
||||
]),
|
||||
{ requiredModelType: 'checkpoints' }
|
||||
)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
await wizard.uploadModel()
|
||||
|
||||
expect(wizard.uploadTypeMismatch.value).not.toBeNull()
|
||||
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/54321'
|
||||
await nextTick()
|
||||
|
||||
expect(wizard.uploadError.value).toBe('')
|
||||
expect(wizard.uploadTypeMismatch.value).toBeNull()
|
||||
})
|
||||
|
||||
it('returns null while another upload is in progress', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
type UploadResult = Awaited<
|
||||
ReturnType<typeof assetService.uploadAssetAsync>
|
||||
>
|
||||
let resolveUpload!: (value: UploadResult) => void
|
||||
vi.mocked(assetService.uploadAssetAsync).mockReturnValue(
|
||||
new Promise<UploadResult>((resolve) => {
|
||||
resolveUpload = resolve
|
||||
})
|
||||
)
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
const firstUpload = wizard.uploadModel()
|
||||
await nextTick()
|
||||
|
||||
await expect(wizard.uploadModel()).resolves.toBeNull()
|
||||
|
||||
resolveUpload({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-1',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'checkpoints']
|
||||
}
|
||||
})
|
||||
|
||||
await expect(firstUpload).resolves.toEqual(
|
||||
expect.objectContaining({ status: 'success' })
|
||||
)
|
||||
})
|
||||
|
||||
it('returns null when no model type is selected', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(assetService.uploadAssetAsync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('reports an upload error when no valid source is detected', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://example.com/model'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(assetService.uploadAssetAsync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uploads preview images and passes the preview id to the model upload', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetFromBase64).mockResolvedValue(
|
||||
fromPartial({ id: 'preview-1' })
|
||||
)
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-1',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'checkpoints']
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.wizardData.value.metadata = {
|
||||
content_length: 100,
|
||||
final_url: 'https://civitai.com/models/12345',
|
||||
filename: 'model.safetensors'
|
||||
}
|
||||
wizard.wizardData.value.previewImage = 'data:image/jpeg;base64,abc'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
await wizard.uploadModel()
|
||||
|
||||
expect(assetService.uploadAssetFromBase64).toHaveBeenCalledWith({
|
||||
data: 'data:image/jpeg;base64,abc',
|
||||
name: 'model_preview.jpg',
|
||||
tags: ['preview']
|
||||
})
|
||||
expect(assetService.uploadAssetAsync).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ preview_id: 'preview-1' })
|
||||
)
|
||||
})
|
||||
|
||||
it('continues model upload when preview upload fails', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetFromBase64).mockRejectedValue(
|
||||
new Error('preview failed')
|
||||
)
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-1',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'checkpoints']
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.wizardData.value.metadata = {
|
||||
content_length: 100,
|
||||
final_url: 'https://civitai.com/models/12345',
|
||||
name: 'model'
|
||||
}
|
||||
wizard.wizardData.value.previewImage = 'data:image/webp;base64,abc'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
await wizard.uploadModel()
|
||||
|
||||
expect(assetService.uploadAssetAsync).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ preview_id: undefined })
|
||||
)
|
||||
expect(wizard.uploadStatus.value).toBe('success')
|
||||
})
|
||||
|
||||
it('treats an already completed async upload as success', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'async',
|
||||
task: {
|
||||
task_id: 'task-complete',
|
||||
status: 'completed',
|
||||
message: 'Download complete'
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.wizardData.value.metadata = {
|
||||
content_length: 100,
|
||||
final_url: 'https://civitai.com/models/12345',
|
||||
filename: 'queued.safetensors'
|
||||
}
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toEqual({
|
||||
filename: 'queued.safetensors',
|
||||
modelType: 'checkpoints',
|
||||
status: 'success'
|
||||
})
|
||||
expect(wizard.uploadStatus.value).toBe('success')
|
||||
})
|
||||
|
||||
it('cleans up an immediately resolved async watcher', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
const { useAssetDownloadStore } =
|
||||
await import('@/stores/assetDownloadStore')
|
||||
const assetDownloadStore = useAssetDownloadStore()
|
||||
assetDownloadStore.trackDownload(
|
||||
'task-ready',
|
||||
'checkpoints',
|
||||
'ready.safetensors'
|
||||
)
|
||||
const { api } = await import('@/scripts/api')
|
||||
const handler = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find((c) => c[0] === 'asset_download')?.[1] as
|
||||
| ((e: CustomEvent) => void)
|
||||
| undefined
|
||||
expect(handler).toBeDefined()
|
||||
handler!(
|
||||
new CustomEvent('asset_download', {
|
||||
detail: {
|
||||
task_id: 'task-ready',
|
||||
asset_id: 'asset-ready',
|
||||
asset_name: 'ready.safetensors',
|
||||
bytes_total: 100,
|
||||
bytes_downloaded: 100,
|
||||
progress: 100,
|
||||
status: 'completed' as const
|
||||
}
|
||||
})
|
||||
)
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'async',
|
||||
task: {
|
||||
task_id: 'task-ready',
|
||||
status: 'created',
|
||||
message: 'Download queued'
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
await wizard.uploadModel()
|
||||
await nextTick()
|
||||
|
||||
expect(wizard.uploadStatus.value).toBe('success')
|
||||
})
|
||||
|
||||
it('uses the default failed-download message when no error is available', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'async',
|
||||
task: {
|
||||
task_id: 'task-fallback-fail',
|
||||
status: 'created',
|
||||
message: 'Download queued'
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
await wizard.uploadModel()
|
||||
|
||||
const { api } = await import('@/scripts/api')
|
||||
const handler = vi
|
||||
.mocked(api.addEventListener)
|
||||
.mock.calls.find((c) => c[0] === 'asset_download')?.[1] as
|
||||
| ((e: CustomEvent) => void)
|
||||
| undefined
|
||||
expect(handler).toBeDefined()
|
||||
handler!(
|
||||
new CustomEvent('asset_download', {
|
||||
detail: {
|
||||
task_id: 'task-fallback-fail',
|
||||
asset_id: '',
|
||||
asset_name: '',
|
||||
bytes_total: 1000,
|
||||
bytes_downloaded: 500,
|
||||
progress: 50,
|
||||
status: 'failed' as const
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
await nextTick()
|
||||
|
||||
expect(wizard.uploadStatus.value).toBe('error')
|
||||
expect(wizard.uploadError.value).toBe('assetBrowser.downloadFailed')
|
||||
})
|
||||
|
||||
it('uses fallback labels for unknown mismatch types', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'sync',
|
||||
asset: {
|
||||
id: 'asset-unknown',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models']
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes, {
|
||||
requiredModelType: 'unknown-required'
|
||||
})
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(wizard.uploadTypeMismatch.value).toEqual({
|
||||
importedModelType: undefined,
|
||||
importedModelTypeLabel: undefined,
|
||||
requiredModelType: 'unknown-required',
|
||||
requiredModelTypeLabel: 'unknown-required'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses a generic upload error for non-error upload failures', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockRejectedValue('failed')
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
const result = await wizard.uploadModel()
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(wizard.uploadStatus.value).toBe('error')
|
||||
expect(wizard.uploadError.value).toBe('Failed to upload model')
|
||||
})
|
||||
|
||||
it('navigates backward only after the first step', () => {
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
|
||||
wizard.goToPreviousStep()
|
||||
expect(wizard.currentStep.value).toBe(1)
|
||||
|
||||
wizard.currentStep.value = 3
|
||||
wizard.goToPreviousStep()
|
||||
|
||||
expect(wizard.currentStep.value).toBe(2)
|
||||
})
|
||||
|
||||
it('resets wizard state and cancels pending async status watching', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
|
||||
type: 'async',
|
||||
task: {
|
||||
task_id: 'task-reset',
|
||||
status: 'created',
|
||||
message: 'Download queued'
|
||||
}
|
||||
})
|
||||
|
||||
const wizard = setupUploadModelWizard(modelTypes)
|
||||
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
|
||||
wizard.wizardData.value.name = 'Model'
|
||||
wizard.wizardData.value.tags = ['checkpoints']
|
||||
wizard.selectedModelType.value = 'checkpoints'
|
||||
|
||||
await wizard.uploadModel()
|
||||
wizard.resetWizard()
|
||||
|
||||
expect(wizard.currentStep.value).toBe(1)
|
||||
expect(wizard.uploadStatus.value).toBeUndefined()
|
||||
expect(wizard.uploadError.value).toBe('')
|
||||
expect(wizard.wizardData.value).toEqual({
|
||||
url: '',
|
||||
name: '',
|
||||
tags: []
|
||||
})
|
||||
expect(wizard.selectedModelType.value).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns the synced asset filename for sync imports', async () => {
|
||||
const { assetService } =
|
||||
await import('@/platform/assets/services/assetService')
|
||||
|
||||
@@ -12,6 +12,7 @@ import { api } from '@/scripts/api'
|
||||
|
||||
const mockDistributionState = vi.hoisted(() => ({ isCloud: false }))
|
||||
const mockSettingStoreGet = vi.hoisted(() => vi.fn(() => false))
|
||||
const mockGetCategoryForNodeType = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
@@ -33,7 +34,7 @@ vi.mock('@/stores/modelToNodeStore', () => {
|
||||
return {
|
||||
useModelToNodeStore: vi.fn(() => ({
|
||||
getRegisteredNodeTypes: () => registeredNodeTypes,
|
||||
getCategoryForNodeType: vi.fn()
|
||||
getCategoryForNodeType: mockGetCategoryForNodeType
|
||||
}))
|
||||
}
|
||||
})
|
||||
@@ -172,6 +173,28 @@ describe(assetService.getAssetMetadata, () => {
|
||||
).rejects.toThrow('File too large')
|
||||
})
|
||||
|
||||
it('falls back to the unknown localized message for unrecognized error codes', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse({ code: 'NOT_A_REAL_CODE' }, { ok: false, status: 400 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.getAssetMetadata('https://example.com/model.safetensors')
|
||||
).rejects.toThrow('Unknown error')
|
||||
})
|
||||
|
||||
it('falls back to unknown when error JSON cannot be parsed', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 400,
|
||||
json: vi.fn().mockRejectedValue(new Error('bad json'))
|
||||
} as unknown as Response)
|
||||
|
||||
await expect(
|
||||
assetService.getAssetMetadata('https://example.com/model.safetensors')
|
||||
).rejects.toThrow('Unknown error')
|
||||
})
|
||||
|
||||
it('throws a localized message when validation reports is_valid=false', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse({
|
||||
@@ -189,6 +212,20 @@ describe(assetService.getAssetMetadata, () => {
|
||||
).rejects.toThrow('Unsafe virus scan')
|
||||
})
|
||||
|
||||
it('falls back to unknown when validation errors are absent', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse({
|
||||
content_length: 100,
|
||||
final_url: 'https://example.com/model.safetensors',
|
||||
validation: { is_valid: false }
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.getAssetMetadata('https://example.com/model.safetensors')
|
||||
).rejects.toThrow('Unknown error')
|
||||
})
|
||||
|
||||
it('encodes the URL in the query string', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse({
|
||||
@@ -208,12 +245,115 @@ describe(assetService.getAssetMetadata, () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.getAssetsForNodeType, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockGetCategoryForNodeType.mockReset()
|
||||
})
|
||||
|
||||
it('returns an empty list for invalid node types without fetching', async () => {
|
||||
await expect(assetService.getAssetsForNodeType('')).resolves.toEqual([])
|
||||
|
||||
expect(fetchApiMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns an empty list when the node type has no asset category', async () => {
|
||||
mockGetCategoryForNodeType.mockReturnValue(undefined)
|
||||
|
||||
await expect(
|
||||
assetService.getAssetsForNodeType('UnknownNode')
|
||||
).resolves.toEqual([])
|
||||
|
||||
expect(fetchApiMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('fetches category assets with default pagination', async () => {
|
||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||
const assets = [
|
||||
validAsset({ id: 'ckpt-1', tags: ['models', 'checkpoints'] })
|
||||
]
|
||||
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse(assets))
|
||||
|
||||
await expect(
|
||||
assetService.getAssetsForNodeType('CheckpointLoaderSimple')
|
||||
).resolves.toEqual(assets)
|
||||
|
||||
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
|
||||
const params = new URL(requestedUrl, 'http://localhost').searchParams
|
||||
expect(params.get('include_tags')).toBe('models,checkpoints')
|
||||
expect(params.get('limit')).toBe('500')
|
||||
expect(params.has('offset')).toBe(false)
|
||||
})
|
||||
|
||||
it('passes positive offsets for category asset pagination', async () => {
|
||||
mockGetCategoryForNodeType.mockReturnValue('loras')
|
||||
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse([]))
|
||||
|
||||
await assetService.getAssetsForNodeType('LoraLoader', {
|
||||
limit: 25,
|
||||
offset: 50
|
||||
})
|
||||
|
||||
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
|
||||
const params = new URL(requestedUrl, 'http://localhost').searchParams
|
||||
expect(params.get('include_tags')).toBe('models,loras')
|
||||
expect(params.get('limit')).toBe('25')
|
||||
expect(params.get('offset')).toBe('50')
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.getAssetDetails, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('throws when the details response is not ok', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse({}, { ok: false, status: 404 })
|
||||
)
|
||||
|
||||
await expect(assetService.getAssetDetails('missing')).rejects.toThrow(
|
||||
'Unable to load asset details for missing: Server returned 404'
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when the details response is invalid', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse({ id: 'asset-1' }))
|
||||
|
||||
await expect(assetService.getAssetDetails('asset-1')).rejects.toThrow(
|
||||
/Invalid asset response/
|
||||
)
|
||||
})
|
||||
|
||||
it('returns validated asset details', async () => {
|
||||
const asset = validAsset({ id: 'asset-details' })
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse(asset))
|
||||
|
||||
await expect(
|
||||
assetService.getAssetDetails('asset-details')
|
||||
).resolves.toEqual(asset)
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.uploadAssetFromUrl, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
assetService.invalidateInputAssetsIncludingPublic()
|
||||
})
|
||||
|
||||
it('throws when URL upload returns a non-ok response', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 500 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.uploadAssetFromUrl({
|
||||
url: 'https://example.com/input.png',
|
||||
name: 'input.png'
|
||||
})
|
||||
).rejects.toThrow('Failed to upload asset')
|
||||
})
|
||||
|
||||
it('does not invalidate cached input assets when the upload response is invalid', async () => {
|
||||
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
@@ -294,6 +434,61 @@ describe(assetService.uploadAssetFromBase64, () => {
|
||||
expect(fetchApiMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws when base64 upload returns a non-ok response', async () => {
|
||||
const fetchSpy = vi
|
||||
.spyOn(globalThis, 'fetch')
|
||||
.mockResolvedValueOnce(new Response('hello'))
|
||||
try {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 507 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.uploadAssetFromBase64({
|
||||
data: 'data:text/plain;base64,aGVsbG8=',
|
||||
name: 'input.txt'
|
||||
})
|
||||
).rejects.toThrow('Failed to upload asset from base64: 507')
|
||||
} finally {
|
||||
fetchSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('posts base64 uploads with tags and user metadata', async () => {
|
||||
const uploadedAsset = {
|
||||
...validAsset({ id: 'uploaded-input', tags: ['input'] }),
|
||||
created_new: false
|
||||
}
|
||||
const fetchSpy = vi
|
||||
.spyOn(globalThis, 'fetch')
|
||||
.mockResolvedValueOnce(new Response('hello'))
|
||||
try {
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse(uploadedAsset))
|
||||
|
||||
const result = await assetService.uploadAssetFromBase64({
|
||||
data: 'data:text/plain;base64,aGVsbG8=',
|
||||
name: 'input.txt',
|
||||
tags: ['input', 'mask'],
|
||||
user_metadata: { source: 'paste' }
|
||||
})
|
||||
|
||||
expect(result).toEqual(uploadedAsset)
|
||||
const request = fetchApiMock.mock.calls[0]?.[1]
|
||||
expect(request).toEqual(expect.objectContaining({ method: 'POST' }))
|
||||
expect(request?.body).toBeInstanceOf(FormData)
|
||||
const formData = request?.body
|
||||
if (!(formData instanceof FormData)) {
|
||||
throw new Error('Expected base64 upload body to be FormData')
|
||||
}
|
||||
expect(formData.get('tags')).toBe(JSON.stringify(['input', 'mask']))
|
||||
expect(formData.get('user_metadata')).toBe(
|
||||
JSON.stringify({ source: 'paste' })
|
||||
)
|
||||
} finally {
|
||||
fetchSpy.mockRestore()
|
||||
}
|
||||
})
|
||||
|
||||
it('does not invalidate cached input assets when the upload response is invalid', async () => {
|
||||
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
@@ -355,6 +550,7 @@ describe(assetService.uploadAssetFromBase64, () => {
|
||||
describe(assetService.uploadAssetAsync, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
assetService.invalidateInputAssetsIncludingPublic()
|
||||
})
|
||||
|
||||
it('returns an async result when the server responds 202', async () => {
|
||||
@@ -389,6 +585,64 @@ describe(assetService.uploadAssetAsync, () => {
|
||||
asset: expect.objectContaining({ id: 'asset-2' })
|
||||
})
|
||||
})
|
||||
|
||||
it('throws when the async upload response is not ok', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 502 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.uploadAssetAsync({
|
||||
source_url: 'https://example.com/model.safetensors'
|
||||
})
|
||||
).rejects.toThrow('Failed to upload asset')
|
||||
})
|
||||
|
||||
it('throws when an async upload task response is invalid', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse({ task_id: 'task-1', status: 'waiting' }, { status: 202 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.uploadAssetAsync({
|
||||
source_url: 'https://example.com/model.safetensors'
|
||||
})
|
||||
).rejects.toThrow('Failed to parse async upload response')
|
||||
})
|
||||
|
||||
it('throws when a sync upload asset response is invalid', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse({ id: 'asset-2' }))
|
||||
|
||||
await expect(
|
||||
assetService.uploadAssetAsync({
|
||||
source_url: 'https://example.com/model.safetensors'
|
||||
})
|
||||
).rejects.toThrow('Failed to parse sync upload response')
|
||||
})
|
||||
|
||||
it('invalidates cached input assets for completed async input uploads', async () => {
|
||||
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
|
||||
const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })]
|
||||
fetchApiMock
|
||||
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
|
||||
.mockResolvedValueOnce(
|
||||
buildResponse(
|
||||
{ task_id: 'task-1', status: 'completed' },
|
||||
{ ok: true, status: 202 }
|
||||
)
|
||||
)
|
||||
.mockResolvedValueOnce(buildAssetListResponse(freshAssets))
|
||||
|
||||
await assetService.getInputAssetsIncludingPublic()
|
||||
await assetService.uploadAssetAsync({
|
||||
source_url: 'https://example.com/input.png',
|
||||
tags: ['input']
|
||||
})
|
||||
const refreshed = await assetService.getInputAssetsIncludingPublic()
|
||||
|
||||
expect(refreshed).toEqual(freshAssets)
|
||||
expect(fetchApiMock).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.deleteAsset, () => {
|
||||
@@ -416,6 +670,94 @@ describe(assetService.deleteAsset, () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.addAssetTags, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
assetService.invalidateInputAssetsIncludingPublic()
|
||||
})
|
||||
|
||||
it('posts tags and returns the parsed tag operation result', async () => {
|
||||
const result = { total_tags: ['input', 'mask'], added: ['mask'] }
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse(result))
|
||||
|
||||
await expect(
|
||||
assetService.addAssetTags('asset-1', ['mask'])
|
||||
).resolves.toEqual(result)
|
||||
|
||||
expect(fetchApiMock).toHaveBeenCalledWith(
|
||||
'/assets/asset-1/tags',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ tags: ['mask'] })
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when adding tags fails', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 403 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.addAssetTags('asset-1', ['mask'])
|
||||
).rejects.toThrow(
|
||||
'Unable to add tags to asset asset-1: Server returned 403'
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when the add-tags response is invalid', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse({ added: ['mask'] }))
|
||||
|
||||
await expect(
|
||||
assetService.addAssetTags('asset-1', ['mask'])
|
||||
).rejects.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.removeAssetTags, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
assetService.invalidateInputAssetsIncludingPublic()
|
||||
})
|
||||
|
||||
it('deletes tags and returns the parsed tag operation result', async () => {
|
||||
const result = { total_tags: ['input'], removed: ['mask'] }
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse(result))
|
||||
|
||||
await expect(
|
||||
assetService.removeAssetTags('asset-1', ['mask'])
|
||||
).resolves.toEqual(result)
|
||||
|
||||
expect(fetchApiMock).toHaveBeenCalledWith(
|
||||
'/assets/asset-1/tags',
|
||||
expect.objectContaining({
|
||||
method: 'DELETE',
|
||||
body: JSON.stringify({ tags: ['mask'] })
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when removing tags fails', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 404 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.removeAssetTags('asset-1', ['mask'])
|
||||
).rejects.toThrow(
|
||||
'Unable to remove tags from asset asset-1: Server returned 404'
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when the remove-tags response is invalid', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse({ removed: ['mask'] }))
|
||||
|
||||
await expect(
|
||||
assetService.removeAssetTags('asset-1', ['mask'])
|
||||
).rejects.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.getAssetModelFolders, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -481,6 +823,16 @@ describe(assetService.updateAsset, () => {
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when the update response is not ok', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 409 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.updateAsset('asset-1', { name: 'renamed.safetensors' })
|
||||
).rejects.toThrow('Unable to update asset asset-1: Server returned 409')
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.getAssetsByTag, () => {
|
||||
@@ -515,6 +867,21 @@ describe(assetService.getAssetsByTag, () => {
|
||||
expect(params.get('include_tags')).toBe('input')
|
||||
expect(params.get('exclude_tags')).toBe(MISSING_TAG)
|
||||
})
|
||||
|
||||
it('forwards explicit public filtering and offset pagination', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse([]))
|
||||
|
||||
await assetService.getAssetsByTag('input', false, {
|
||||
limit: 30,
|
||||
offset: 60
|
||||
})
|
||||
|
||||
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
|
||||
const params = new URL(requestedUrl, 'http://localhost').searchParams
|
||||
expect(params.get('include_public')).toBe('false')
|
||||
expect(params.get('limit')).toBe('30')
|
||||
expect(params.get('offset')).toBe('60')
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.getAllAssetsByTag, () => {
|
||||
@@ -562,6 +929,31 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
expect(secondParams.has('offset')).toBe(false)
|
||||
})
|
||||
|
||||
it('uses the default page size when limit is not positive', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse([]))
|
||||
|
||||
await expect(
|
||||
assetService.getAllAssetsByTag('input', true, { limit: 0 })
|
||||
).resolves.toEqual([])
|
||||
|
||||
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
|
||||
const params = new URL(requestedUrl, 'http://localhost').searchParams
|
||||
expect(params.get('limit')).toBe('500')
|
||||
})
|
||||
|
||||
it('throws before fetching when the pagination signal is already aborted', async () => {
|
||||
const controller = new AbortController()
|
||||
controller.abort()
|
||||
|
||||
await expect(
|
||||
assetService.getAllAssetsByTag('input', true, {
|
||||
signal: controller.signal
|
||||
})
|
||||
).rejects.toMatchObject({ name: 'AbortError' })
|
||||
|
||||
expect(fetchApiMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('honors has_more when walking tagged asset pages', async () => {
|
||||
fetchApiMock
|
||||
.mockResolvedValueOnce(
|
||||
@@ -703,6 +1095,75 @@ describe(assetService.getAllAssetsByTag, () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.createAssetExport, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('posts export options and returns the export task', async () => {
|
||||
const task = { task_id: 'export-1', status: 'created', message: 'queued' }
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse(task))
|
||||
|
||||
await expect(
|
||||
assetService.createAssetExport({
|
||||
asset_ids: ['asset-1'],
|
||||
include_previews: true
|
||||
})
|
||||
).resolves.toEqual(task)
|
||||
|
||||
expect(fetchApiMock).toHaveBeenCalledWith(
|
||||
'/assets/export',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
body: JSON.stringify({
|
||||
asset_ids: ['asset-1'],
|
||||
include_previews: true
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('throws when creating an export fails', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 503 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.createAssetExport({ asset_ids: ['asset-1'] })
|
||||
).rejects.toThrow('Failed to create asset export: 503')
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.getExportDownloadUrl, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('returns the export download URL', async () => {
|
||||
const download = {
|
||||
url: 'https://example.com/export.zip',
|
||||
expires_at: '2026-07-01T00:00:00Z'
|
||||
}
|
||||
fetchApiMock.mockResolvedValueOnce(buildResponse(download))
|
||||
|
||||
await expect(
|
||||
assetService.getExportDownloadUrl('export.zip')
|
||||
).resolves.toEqual(download)
|
||||
|
||||
expect(fetchApiMock).toHaveBeenCalledWith('/assets/exports/export.zip')
|
||||
})
|
||||
|
||||
it('throws when export download URL lookup fails', async () => {
|
||||
fetchApiMock.mockResolvedValueOnce(
|
||||
buildResponse(null, { ok: false, status: 404 })
|
||||
)
|
||||
|
||||
await expect(
|
||||
assetService.getExportDownloadUrl('missing.zip')
|
||||
).rejects.toThrow('Failed to get export download URL: 404')
|
||||
})
|
||||
})
|
||||
|
||||
describe(assetService.getInputAssetsIncludingPublic, () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -729,6 +1190,17 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
|
||||
expect(params.get('limit')).toBe('500')
|
||||
})
|
||||
|
||||
it('throws before starting a shared request when the caller signal is already aborted', async () => {
|
||||
const controller = new AbortController()
|
||||
controller.abort()
|
||||
|
||||
await expect(
|
||||
assetService.getInputAssetsIncludingPublic(controller.signal)
|
||||
).rejects.toMatchObject({ name: 'AbortError' })
|
||||
|
||||
expect(fetchApiMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('fetches fresh input assets after explicit invalidation', async () => {
|
||||
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
|
||||
const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })]
|
||||
|
||||
37
src/platform/assets/utils/assetTypeUtil.test.ts
Normal file
37
src/platform/assets/utils/assetTypeUtil.test.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { getAssetType } from '@/platform/assets/utils/assetTypeUtil'
|
||||
|
||||
function asset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-1',
|
||||
name: 'image.png',
|
||||
preview_url: '',
|
||||
tags: [],
|
||||
created_at: '',
|
||||
updated_at: '',
|
||||
size: 0,
|
||||
mime_type: 'image/png',
|
||||
user_metadata: {},
|
||||
...overrides
|
||||
} as AssetItem
|
||||
}
|
||||
|
||||
describe('getAssetType', () => {
|
||||
it('prefers the preview URL type over tags', () => {
|
||||
expect(
|
||||
getAssetType(
|
||||
asset({
|
||||
preview_url: '/api/view?filename=image.png&type=temp',
|
||||
tags: ['output']
|
||||
})
|
||||
)
|
||||
).toBe('temp')
|
||||
})
|
||||
|
||||
it('falls back to tags and then the supplied default type', () => {
|
||||
expect(getAssetType(asset({ tags: ['input'] }))).toBe('input')
|
||||
expect(getAssetType(asset(), 'input')).toBe('input')
|
||||
})
|
||||
})
|
||||
62
src/platform/assets/utils/assetUrlUtil.test.ts
Normal file
62
src/platform/assets/utils/assetUrlUtil.test.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { getAssetUrl } from '@/platform/assets/utils/assetUrlUtil'
|
||||
|
||||
const { apiURL } = vi.hoisted(() => ({
|
||||
apiURL: vi.fn((path: string) => `https://comfy.local${path}`)
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: { apiURL }
|
||||
}))
|
||||
|
||||
function asset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-1',
|
||||
name: 'folder image.png',
|
||||
preview_url: '',
|
||||
tags: ['output'],
|
||||
created_at: '',
|
||||
updated_at: '',
|
||||
size: 0,
|
||||
mime_type: 'image/png',
|
||||
user_metadata: {},
|
||||
...overrides
|
||||
} as AssetItem
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
apiURL.mockClear()
|
||||
})
|
||||
|
||||
describe('getAssetUrl', () => {
|
||||
it('builds encoded view URLs with type and subfolder', () => {
|
||||
const url = getAssetUrl(
|
||||
asset({
|
||||
user_metadata: { subfolder: 'nested/path' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(apiURL).toHaveBeenCalledWith(
|
||||
'/view?filename=folder+image.png&type=output&subfolder=nested%2Fpath'
|
||||
)
|
||||
expect(url).toBe(
|
||||
'https://comfy.local/view?filename=folder+image.png&type=output&subfolder=nested%2Fpath'
|
||||
)
|
||||
})
|
||||
|
||||
it('uses preview URL type and omits empty subfolders', () => {
|
||||
getAssetUrl(
|
||||
asset({
|
||||
preview_url: '/api/view?filename=image.png&type=temp',
|
||||
tags: ['output'],
|
||||
user_metadata: { subfolder: '' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(apiURL).toHaveBeenCalledWith(
|
||||
'/view?filename=folder+image.png&type=temp'
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
@@ -28,6 +29,8 @@ interface HostAssetWidget extends IBaseWidget<
|
||||
node: LGraphNode
|
||||
}
|
||||
|
||||
type AssetWidget = IBaseWidget<string | undefined, 'asset', IWidgetAssetOptions>
|
||||
|
||||
type OnWidgetChanged = NonNullable<LGraphNode['onWidgetChanged']>
|
||||
|
||||
function checkpointAsset(name: string): AssetItem {
|
||||
@@ -166,4 +169,118 @@ describe('createAssetWidget', () => {
|
||||
)
|
||||
expect(captureCanvasState).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('falls back to widget name and empty current value for cloned widgets', async () => {
|
||||
const { node } = createAssetWidgetNode()
|
||||
const sourceWidget = createAssetWidget({
|
||||
node,
|
||||
widgetName: 'lora_name',
|
||||
nodeTypeForBrowser: 'LoraLoader'
|
||||
})
|
||||
assertAssetOptions(sourceWidget.options)
|
||||
const clonedWidget: AssetWidget = {
|
||||
type: 'asset',
|
||||
name: 'lora_name',
|
||||
value: undefined,
|
||||
options: sourceWidget.options,
|
||||
y: 0
|
||||
}
|
||||
|
||||
await sourceWidget.options.openModal(clonedWidget)
|
||||
|
||||
expect(firstShowOptions()).toMatchObject({
|
||||
nodeType: 'LoraLoader',
|
||||
inputName: 'lora_name',
|
||||
currentValue: ''
|
||||
})
|
||||
})
|
||||
|
||||
it('rejects malformed asset selections', async () => {
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const { node } = createAssetWidgetNode()
|
||||
const widget = createAssetWidget({
|
||||
node,
|
||||
widgetName: 'ckpt_name',
|
||||
nodeTypeForBrowser: 'CheckpointLoaderSimple',
|
||||
defaultValue: 'fake_model.safetensors'
|
||||
})
|
||||
assertAssetOptions(widget.options)
|
||||
|
||||
await widget.options.openModal(widget)
|
||||
firstShowOptions().onAssetSelected?.(
|
||||
fromPartial({ id: 'asset-without-name' })
|
||||
)
|
||||
|
||||
expect(widget.value).toBe('fake_model.safetensors')
|
||||
expect(captureCanvasState).not.toHaveBeenCalled()
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('rejects invalid asset filenames', async () => {
|
||||
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const { node } = createAssetWidgetNode()
|
||||
const widget = createAssetWidget({
|
||||
node,
|
||||
widgetName: 'ckpt_name',
|
||||
nodeTypeForBrowser: 'CheckpointLoaderSimple',
|
||||
defaultValue: 'fake_model.safetensors'
|
||||
})
|
||||
assertAssetOptions(widget.options)
|
||||
|
||||
await widget.options.openModal(widget)
|
||||
firstShowOptions().onAssetSelected?.(checkpointAsset('../bad.safetensors'))
|
||||
|
||||
expect(widget.value).toBe('fake_model.safetensors')
|
||||
expect(captureCanvasState).not.toHaveBeenCalled()
|
||||
consoleError.mockRestore()
|
||||
})
|
||||
|
||||
it('updates ownerless cloned widgets without node callbacks', async () => {
|
||||
const { node, onWidgetChanged } = createAssetWidgetNode()
|
||||
const sourceWidget = createAssetWidget({
|
||||
node,
|
||||
widgetName: 'ckpt_name',
|
||||
nodeTypeForBrowser: 'CheckpointLoaderSimple',
|
||||
defaultValue: 'fake_model.safetensors'
|
||||
})
|
||||
assertAssetOptions(sourceWidget.options)
|
||||
const callback = vi.fn<NonNullable<IBaseWidget['callback']>>()
|
||||
const clonedWidget: AssetWidget = {
|
||||
type: 'asset',
|
||||
name: 'ckpt_name',
|
||||
value: 'fake_model.safetensors',
|
||||
callback,
|
||||
options: sourceWidget.options,
|
||||
y: 0
|
||||
}
|
||||
|
||||
await sourceWidget.options.openModal(clonedWidget)
|
||||
firstShowOptions().onAssetSelected?.(
|
||||
checkpointAsset('real_model.safetensors')
|
||||
)
|
||||
|
||||
expect(clonedWidget.value).toBe('real_model.safetensors')
|
||||
expect(callback).toHaveBeenCalledWith('real_model.safetensors')
|
||||
expect(onWidgetChanged).not.toHaveBeenCalled()
|
||||
expect(captureCanvasState).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('does not capture canvas state when the selection is unchanged', async () => {
|
||||
const { node } = createAssetWidgetNode()
|
||||
const widget = createAssetWidget({
|
||||
node,
|
||||
widgetName: 'ckpt_name',
|
||||
nodeTypeForBrowser: 'CheckpointLoaderSimple',
|
||||
defaultValue: 'fake_model.safetensors'
|
||||
})
|
||||
assertAssetOptions(widget.options)
|
||||
|
||||
await widget.options.openModal(widget)
|
||||
firstShowOptions().onAssetSelected?.(
|
||||
checkpointAsset('fake_model.safetensors')
|
||||
)
|
||||
|
||||
expect(widget.value).toBe('fake_model.safetensors')
|
||||
expect(captureCanvasState).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
|
||||
// Shows loading affordances
|
||||
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
|
||||
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'team_700',
|
||||
'yearly'
|
||||
'yearly',
|
||||
{ paymentIntentSource: 'deep_link' }
|
||||
)
|
||||
// Team never goes through the personal checkout path
|
||||
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()
|
||||
|
||||
@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
return
|
||||
}
|
||||
isTeamCheckout.value = true
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle)
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle, {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
if (isActiveSubscription.value) {
|
||||
await accessBillingPortal(undefined, false)
|
||||
} else {
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
}
|
||||
}, reportError)
|
||||
|
||||
|
||||
@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
|
||||
})
|
||||
|
||||
function handleAddCredits() {
|
||||
telemetry?.trackAddApiCreditButtonClicked()
|
||||
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
function handleUpgradeToAddCredits() {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
}
|
||||
|
||||
async function handleWindowFocus() {
|
||||
|
||||
@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import FreeTierDialogContent from './FreeTierDialogContent.vue'
|
||||
|
||||
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
|
||||
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
function renderComponent() {
|
||||
function renderComponent(props?: { reason?: PaymentIntentSource }) {
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
@@ -23,6 +25,7 @@ function renderComponent() {
|
||||
})
|
||||
|
||||
return render(FreeTierDialogContent, {
|
||||
props,
|
||||
global: {
|
||||
plugins: [i18n]
|
||||
}
|
||||
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
|
||||
renderComponent()
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('keeps the generic copy for intent reasons outside the credits variants', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'subscribe_to_run' })
|
||||
expect(
|
||||
screen.getByText('Your credits refresh on Jul 15, 2026.')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('swaps to the out-of-credits copy without the refresh line', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="!reason || reason === 'subscription_required'"
|
||||
v-if="!isCreditsBlockedVariant"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -65,10 +65,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="
|
||||
(!reason || reason === 'subscription_required') &&
|
||||
formattedRenewalDate
|
||||
"
|
||||
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -88,7 +85,7 @@
|
||||
@click="$emit('upgrade')"
|
||||
>
|
||||
{{
|
||||
reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
isCreditsBlockedVariant
|
||||
? $t('subscription.freeTier.upgradeCta')
|
||||
: $t('subscription.freeTier.subscribeCta')
|
||||
}}
|
||||
@@ -103,12 +100,12 @@ import { computed } from 'vue'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
|
||||
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
|
||||
defineProps<{
|
||||
reason?: SubscriptionDialogReason
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
defineEmits<{
|
||||
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
|
||||
})
|
||||
|
||||
const freeTierCredits = computed(() => getTierCredits('free'))
|
||||
|
||||
// Only these two variants replace the generic free-tier copy; any other
|
||||
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
|
||||
const isCreditsBlockedVariant = computed(
|
||||
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
)
|
||||
</script>
|
||||
|
||||
@@ -261,6 +261,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
|
||||
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should use the latest userId value when it changes after mount', async () => {
|
||||
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
})
|
||||
|
||||
@@ -277,13 +277,19 @@ import type {
|
||||
TierKey,
|
||||
TierPricing
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import {
|
||||
recordPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
|
||||
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
@@ -321,6 +327,10 @@ interface PricingTierConfig {
|
||||
isPopular?: boolean
|
||||
}
|
||||
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
chooseTeamWorkspace: []
|
||||
}>()
|
||||
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
} as const
|
||||
const previousPlan = currentPlanDescriptor.value
|
||||
const checkoutAttribution = await getCheckoutAttributionForCloud()
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
})
|
||||
}
|
||||
const beginCheckoutMetadata = userId.value
|
||||
? {
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change' as const,
|
||||
...(reason ? { payment_intent_source: reason } : {}),
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
}
|
||||
: null
|
||||
// Pass the target tier to create a deep link to subscription update confirmation
|
||||
const checkoutTier = getCheckoutTier(
|
||||
targetPlan.tierKey,
|
||||
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
|
||||
if (downgrade) {
|
||||
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
|
||||
await accessBillingPortal()
|
||||
const didOpenPortal = await accessBillingPortal()
|
||||
if (didOpenPortal && beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
|
||||
}
|
||||
} else {
|
||||
const didOpenPortal = await accessBillingPortal(checkoutTier)
|
||||
if (!didOpenPortal) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
payment_intent_source: reason,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
|
||||
...(previousPlan
|
||||
? { previous_cycle: previousPlan.billingCycle }
|
||||
: {})
|
||||
})
|
||||
if (beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
beginCheckoutMetadata,
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await performSubscriptionCheckout(
|
||||
tierKey,
|
||||
currentBillingCycle.value,
|
||||
true
|
||||
)
|
||||
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
|
||||
paymentIntentSource: reason
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
|
||||
@@ -56,7 +56,7 @@ const handleSubscribe = () => {
|
||||
current_tier: tier.value?.toLowerCase()
|
||||
})
|
||||
isAwaitingStripeSubscription.value = true
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
|
||||
@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
|
||||
trackRunButton({ subscribe_to_run: true })
|
||||
}
|
||||
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -48,7 +48,9 @@
|
||||
v-if="isActiveSubscription"
|
||||
variant="primary"
|
||||
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
|
||||
@click="showSubscriptionDialog"
|
||||
@click="
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
"
|
||||
>
|
||||
{{ $t('subscription.upgradePlan') }}
|
||||
</Button>
|
||||
|
||||
@@ -33,7 +33,11 @@
|
||||
</i18n-t>
|
||||
</div>
|
||||
|
||||
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
|
||||
<PricingTable
|
||||
:reason
|
||||
class="flex-1"
|
||||
@choose-team-workspace="handleChooseTeam"
|
||||
/>
|
||||
|
||||
<!-- Contact and Enterprise Links -->
|
||||
<div class="flex flex-col items-center gap-2">
|
||||
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
const { onClose, reason, onChooseTeam } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
onChooseTeam?: () => void
|
||||
}>()
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
|
||||
)
|
||||
return
|
||||
case 'subscription':
|
||||
void dialogService.showSubscriptionRequiredDialog()
|
||||
void dialogService.showSubscriptionRequiredDialog({
|
||||
reason: 'subscription_required'
|
||||
})
|
||||
return
|
||||
case 'credits':
|
||||
void dialogService.showTopUpCreditsDialog({
|
||||
|
||||
@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
describe('usePricingTableUrlLoader', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
reason: 'deep_link',
|
||||
planMode: undefined
|
||||
})
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
reason: 'deep_link'
|
||||
})
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
})
|
||||
|
||||
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('denies, strips, and clears together when the user is not eligible', async () => {
|
||||
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({
|
||||
query: { other: 'param' }
|
||||
})
|
||||
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
)
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
|
||||
'pricing'
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
mergePreservedQueryIntoQuery
|
||||
} from '@/platform/navigation/preservedQueryManager'
|
||||
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
|
||||
const planMode =
|
||||
param === 'team' || param === 'personal' ? param : undefined
|
||||
|
||||
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
|
||||
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
|
||||
})
|
||||
}, reportError)
|
||||
|
||||
const showSubscriptionDialog = (options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) => {
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: subscriptionTier.value?.toLowerCase(),
|
||||
reason: options?.reason
|
||||
})
|
||||
|
||||
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
|
||||
void showSubscriptionRequiredDialog(options)
|
||||
}
|
||||
|
||||
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
|
||||
await fetchSubscriptionStatus()
|
||||
|
||||
if (!isSubscribedOrIsNotCloud.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
|
||||
}))
|
||||
|
||||
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
|
||||
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
|
||||
const {
|
||||
mockIsCloud,
|
||||
mockTrackHelpResourceClicked,
|
||||
mockTrackAddApiCreditButtonClicked
|
||||
} = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockTrackHelpResourceClicked: vi.fn()
|
||||
mockTrackHelpResourceClicked: vi.fn(),
|
||||
mockTrackAddApiCreditButtonClicked: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () =>
|
||||
mockIsCloud.value
|
||||
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
|
||||
? {
|
||||
trackHelpResourceClicked: mockTrackHelpResourceClicked,
|
||||
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
|
||||
}
|
||||
: null
|
||||
}))
|
||||
|
||||
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
|
||||
const { handleAddApiCredits } = useSubscriptionActions()
|
||||
handleAddApiCredits()
|
||||
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
|
||||
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
|
||||
})
|
||||
|
||||
const handleAddApiCredits = () => {
|
||||
telemetry?.trackAddApiCreditButtonClicked({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
|
||||
const mockCloseDialog = vi.fn()
|
||||
const mockShowLayoutDialog = vi.fn()
|
||||
const mockShowTeamWorkspacesDialog = vi.fn()
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
|
||||
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
|
||||
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
|
||||
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({
|
||||
isFreeTier: mockIsFreeTier,
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan,
|
||||
tier: mockTier
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
|
||||
useWorkspaceUI: () => ({
|
||||
permissions: {
|
||||
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
|
||||
mockIsCloud.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
mockIsFreeTier.value = false
|
||||
mockTier.value = 'FREE'
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
mockIsLegacyTeamPlan.value = false
|
||||
mockCanManageSubscription.value = true
|
||||
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
|
||||
const props = mockShowLayoutDialog.mock.calls[0][0].props
|
||||
expect(props.initialPlanMode).toBe('team')
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the caller reason and current tier', () => {
|
||||
mockTier.value = 'STANDARD'
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
current_tier: 'standard',
|
||||
reason: 'upgrade_to_add_credits'
|
||||
})
|
||||
})
|
||||
|
||||
it('tracks modal_opened on the workspace (unified) path too', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'subscribe_to_run' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not track modal_opened for the inactive member dialog', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockCanManageSubscription.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not track on non-cloud', () => {
|
||||
mockIsCloud.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('show', () => {
|
||||
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
|
||||
expect.objectContaining({ key: 'subscription-required' })
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the reason for the free-tier dialog', () => {
|
||||
mockIsFreeTier.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { show } = useSubscriptionDialog()
|
||||
|
||||
show({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'out_of_credits' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('startTeamWorkspaceUpgradeFlow', () => {
|
||||
|
||||
@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
|
||||
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
|
||||
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
|
||||
|
||||
export type SubscriptionDialogReason =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
|
||||
interface SubscriptionDialogOptions {
|
||||
reason?: SubscriptionDialogReason
|
||||
export interface SubscriptionDialogOptions {
|
||||
reason?: PaymentIntentSource
|
||||
/**
|
||||
* Forces the unified pricing dialog to open on a specific plan tab,
|
||||
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
|
||||
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
|
||||
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
|
||||
}
|
||||
|
||||
// Fired here — the choke point every paywall/pricing dialog variant passes
|
||||
// through — so both the legacy and workspace billing paths emit it.
|
||||
function trackModalOpened(reason?: PaymentIntentSource) {
|
||||
// Resolved lazily to avoid the useBillingContext import cycle (see below).
|
||||
const { tier } = useBillingContext()
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: tier.value?.toLowerCase(),
|
||||
reason
|
||||
})
|
||||
}
|
||||
|
||||
function showPricingTable(options?: SubscriptionDialogOptions) {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
|
||||
return
|
||||
}
|
||||
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
// Shared dialog shell styling for both variants.
|
||||
const dialogComponentProps = {
|
||||
style: 'width: min(1328px, 95vw); max-height: 958px;',
|
||||
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
|
||||
// (not at composable setup) to avoid the useBillingContext import cycle.
|
||||
const { isFreeTier } = useBillingContext()
|
||||
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
const component = defineAsyncComponent(
|
||||
() =>
|
||||
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
|
||||
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
|
||||
sessionStorage.removeItem(RESUME_PRICING_KEY)
|
||||
|
||||
if (!workspaceStore.isInPersonalWorkspace) {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'team_upgrade_resume' })
|
||||
}
|
||||
} catch {
|
||||
// sessionStorage may be unavailable
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
clearPendingSubscriptionCheckoutAttempt,
|
||||
consumePendingSubscriptionCheckoutSuccess,
|
||||
recordPendingSubscriptionCheckoutAttempt
|
||||
} from './subscriptionCheckoutTracker'
|
||||
|
||||
const activeProStatus = {
|
||||
is_active: true,
|
||||
subscription_tier: 'PRO',
|
||||
subscription_duration: 'MONTHLY'
|
||||
} as const
|
||||
|
||||
describe('subscriptionCheckoutTracker', () => {
|
||||
beforeEach(() => {
|
||||
clearPendingSubscriptionCheckoutAttempt()
|
||||
})
|
||||
|
||||
it('round-trips payment_intent_source from attempt to success metadata', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).toMatchObject({
|
||||
tier: 'pro',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('omits payment_intent_source when the attempt had none', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).not.toBeNull()
|
||||
expect(metadata).not.toHaveProperty('payment_intent_source')
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,12 @@ import type {
|
||||
TierKey
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
BeginCheckoutMetadata,
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType,
|
||||
SubscriptionSuccessMetadata
|
||||
} from '@/platform/telemetry/types'
|
||||
|
||||
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
|
||||
const VALID_TIER_KEYS = new Set<TierKey>([
|
||||
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
|
||||
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
|
||||
'comfy:subscription-checkout-attempt-changed'
|
||||
|
||||
type CheckoutType = 'new' | 'change'
|
||||
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
|
||||
|
||||
interface SubscriptionStatusSnapshot {
|
||||
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
|
||||
subscription_duration?: SubscriptionDuration | null
|
||||
}
|
||||
|
||||
interface PendingSubscriptionCheckoutAttempt {
|
||||
export interface PendingSubscriptionCheckoutAttempt {
|
||||
attempt_id: string
|
||||
started_at_ms: number
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface RecordPendingSubscriptionCheckoutAttemptInput {
|
||||
interface PendingSubscriptionCheckoutAttemptInput {
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
const dispatchPendingCheckoutChangeEvent = () => {
|
||||
@@ -168,6 +174,9 @@ const normalizeAttempt = (
|
||||
...(candidate.previous_cycle === 'monthly' ||
|
||||
candidate.previous_cycle === 'yearly'
|
||||
? { previous_cycle: candidate.previous_cycle }
|
||||
: {}),
|
||||
...(typeof candidate.payment_intent_source === 'string'
|
||||
? { payment_intent_source: candidate.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
|
||||
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
|
||||
getPendingSubscriptionCheckoutAttempt() !== null
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: RecordPendingSubscriptionCheckoutAttemptInput
|
||||
export const createPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
const attempt: PendingSubscriptionCheckoutAttempt = {
|
||||
return {
|
||||
attempt_id: createAttemptId(),
|
||||
started_at_ms: Date.now(),
|
||||
tier: input.tier,
|
||||
cycle: input.cycle,
|
||||
checkout_type: input.checkout_type,
|
||||
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
|
||||
...(input.payment_intent_source
|
||||
? { payment_intent_source: input.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
|
||||
export const persistPendingSubscriptionCheckoutAttempt = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
if (!storage) {
|
||||
return attempt
|
||||
}
|
||||
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
return attempt
|
||||
}
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt =>
|
||||
persistPendingSubscriptionCheckoutAttempt(
|
||||
createPendingSubscriptionCheckoutAttempt(input)
|
||||
)
|
||||
|
||||
export const withPendingCheckoutAttemptId = (
|
||||
metadata: BeginCheckoutMetadata,
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): BeginCheckoutMetadata => ({
|
||||
...metadata,
|
||||
checkout_attempt_id: attempt.attempt_id
|
||||
})
|
||||
|
||||
const didAttemptSucceed = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt,
|
||||
status: SubscriptionStatusSnapshot
|
||||
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
|
||||
cycle: attempt.cycle,
|
||||
checkout_type: attempt.checkout_type,
|
||||
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
|
||||
...(attempt.payment_intent_source
|
||||
? { payment_intent_source: attempt.payment_intent_source }
|
||||
: {}),
|
||||
value,
|
||||
currency: 'USD',
|
||||
ecommerce: {
|
||||
|
||||
@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'yearly', true)
|
||||
await performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
ga_client_id: 'ga-client-id',
|
||||
ga_session_id: 'ga-session-id',
|
||||
ga_session_number: 'ga-session-number',
|
||||
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
|
||||
gbraid: 'gbraid-456',
|
||||
wbraid: 'wbraid-789'
|
||||
})
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
JSON.parse(storedAttempt).attempt_id
|
||||
)
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'/customers/cloud-subscription-checkout/pro-yearly'
|
||||
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[SubscriptionCheckout] Failed to collect checkout attribution',
|
||||
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
.spyOn(window, 'open')
|
||||
.mockImplementation(() => window as unknown as Window)
|
||||
|
||||
vi.mocked(global.fetch).mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', {
|
||||
paymentIntentSource: 'out_of_credits'
|
||||
})
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
|
||||
)
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
const pendingAttempt = JSON.parse(storedAttempt)
|
||||
expect(pendingAttempt).toMatchObject({
|
||||
payment_intent_source: 'out_of_credits'
|
||||
})
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
pendingAttempt.attempt_id
|
||||
)
|
||||
openSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('uses the latest userId when it changes after checkout starts', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
mockUserId.value = 'user-late'
|
||||
authHeader.resolve({ Authorization: 'Bearer test-token' })
|
||||
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-late',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
|
||||
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
|
||||
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
const storedAttempt = window.localStorage.getItem(
|
||||
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
|
||||
)
|
||||
expect(storedAttempt).toBeNull()
|
||||
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { getComfyApiBaseUrl } from '@/config/comfyApi'
|
||||
import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import {
|
||||
createPendingSubscriptionCheckoutAttempt,
|
||||
persistPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
type CheckoutTier = TierKey | `${TierKey}-yearly`
|
||||
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
|
||||
return getCheckoutAttribution()
|
||||
}
|
||||
|
||||
interface PerformSubscriptionCheckoutOptions {
|
||||
openInNewTab?: boolean
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Core subscription checkout logic shared between PricingTable and
|
||||
* SubscriptionRedirectView. Handles:
|
||||
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
|
||||
export async function performSubscriptionCheckout(
|
||||
tierKey: TierKey,
|
||||
currentBillingCycle: BillingCycle,
|
||||
openInNewTab: boolean = true
|
||||
options: PerformSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
const { openInNewTab = true, paymentIntentSource } = options
|
||||
|
||||
const authStore = useAuthStore()
|
||||
const { userId } = storeToRefs(authStore)
|
||||
const telemetry = useTelemetry()
|
||||
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
|
||||
const data = await response.json()
|
||||
|
||||
if (data.checkout_url) {
|
||||
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: paymentIntentSource
|
||||
})
|
||||
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...checkoutAttribution
|
||||
})
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
{
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {}),
|
||||
...checkoutAttribution
|
||||
},
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
if (openInNewTab) {
|
||||
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
|
||||
if (!checkoutWindow) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
} else {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
globalThis.location.href = data.checkout_url
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn()
|
||||
}))
|
||||
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
|
||||
vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
|
||||
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
|
||||
workspaceApi: { subscribe: mockSubscribe }
|
||||
}))
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
|
||||
}))
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
|
||||
|
||||
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
billing_op_id: 'op_1'
|
||||
})
|
||||
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly', {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
|
||||
returnUrl: 'https://app.test/payment/success',
|
||||
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
teamCreditStopId: 'team_700'
|
||||
})
|
||||
expect(assignedHref).toBe('https://stripe.test/pay')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'team',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op_1',
|
||||
payment_intent_source: 'deep_link'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
|
||||
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
expect(assignedHref).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not track begin_checkout when subscribe fails', async () => {
|
||||
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
|
||||
|
||||
await expect(
|
||||
performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
).rejects.toThrow('subscribe failed')
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does nothing off cloud', async () => {
|
||||
mockIsCloud.value = false
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
|
||||
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
interface PerformTeamSubscriptionCheckoutOptions {
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
|
||||
* link: subscribes to the per-credit Team plan at the chosen slider stop and
|
||||
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
|
||||
*/
|
||||
export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId: string,
|
||||
billingCycle: BillingCycle
|
||||
billingCycle: BillingCycle,
|
||||
options: PerformTeamSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId
|
||||
})
|
||||
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType: 'new',
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource: options.paymentIntentSource
|
||||
})
|
||||
|
||||
if (response.status === 'needs_payment_method') {
|
||||
// A needs_payment_method response without a URL is unusable: surface it to
|
||||
// the caller's error handling rather than silently dropping the user home
|
||||
|
||||
@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
|
||||
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
|
||||
})
|
||||
|
||||
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
|
||||
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
|
||||
const b: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(a)
|
||||
registry.registerProvider(b)
|
||||
|
||||
const metadata = {
|
||||
user_id: 'user-1',
|
||||
tier: 'pro' as const,
|
||||
cycle: 'monthly' as const,
|
||||
checkout_type: 'new' as const,
|
||||
payment_intent_source: 'subscribe_to_run' as const
|
||||
}
|
||||
registry.trackBeginCheckout(metadata)
|
||||
|
||||
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
|
||||
})
|
||||
|
||||
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
|
||||
const provider: TelemetryProvider = {
|
||||
trackAddApiCreditButtonClicked: vi.fn()
|
||||
}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(provider)
|
||||
|
||||
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(
|
||||
provider.trackAddApiCreditButtonClicked
|
||||
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
|
||||
})
|
||||
|
||||
it('skips providers that do not implement trackSearchQuery', () => {
|
||||
const empty: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
|
||||
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.dispatch((provider) =>
|
||||
provider.trackAddApiCreditButtonClicked?.(metadata)
|
||||
)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('captures begin_checkout with intent metadata', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackBeginCheckout({
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.BEGIN_CHECKOUT,
|
||||
{
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('captures add-credit clicks with their source', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'credits_panel' }
|
||||
)
|
||||
})
|
||||
|
||||
it('captures share attribution events', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
|
||||
import type { RemoteConfig } from '@/platform/remoteConfig/types'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
EnterLinearMetadata,
|
||||
ShareFlowMetadata,
|
||||
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
|
||||
this.trackEvent(eventName, metadata)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
|
||||
}
|
||||
|
||||
trackMonthlySubscriptionSucceeded(
|
||||
|
||||
@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards add-credit clicks with their source', () => {
|
||||
new HostTelemetrySink().trackAddApiCreditButtonClicked({
|
||||
source: 'avatar_menu'
|
||||
})
|
||||
|
||||
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'avatar_menu' }
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when the host bridge is absent', () => {
|
||||
delete window.__comfyDesktop2
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
|
||||
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -12,12 +12,29 @@
|
||||
* 3. Check dist/assets/*.js files contain no tracking code
|
||||
*/
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
import type { AppMode } from '@/utils/appMode'
|
||||
|
||||
export type PaymentIntentSource =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
| 'subscribe_to_run'
|
||||
| 'subscribe_now_button'
|
||||
| 'upgrade_to_add_credits'
|
||||
| 'settings_billing_panel'
|
||||
| 'avatar_menu_plans'
|
||||
| 'team_members_panel'
|
||||
| 'invite_member_upsell'
|
||||
| 'upload_model_upgrade'
|
||||
| 'team_upgrade_resume'
|
||||
|
||||
export type SubscriptionCheckoutType = 'new' | 'change'
|
||||
export type SubscriptionCheckoutTier = TierKey | 'team'
|
||||
|
||||
/**
|
||||
* Authentication metadata for sign-up tracking
|
||||
*/
|
||||
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
|
||||
|
||||
export interface SubscriptionMetadata {
|
||||
current_tier?: string
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export interface AddCreditsClickMetadata {
|
||||
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
|
||||
}
|
||||
|
||||
export interface BeginCheckoutMetadata
|
||||
extends Record<string, unknown>, CheckoutAttributionMetadata {
|
||||
user_id: string
|
||||
tier: TierKey
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
checkout_attempt_id?: string
|
||||
billing_op_id?: string
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface EcommerceItemMetadata {
|
||||
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
|
||||
checkout_attempt_id: string
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
value: number
|
||||
currency: string
|
||||
ecommerce: EcommerceMetadata
|
||||
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
|
||||
metadata?: SubscriptionSuccessMetadata
|
||||
): void
|
||||
trackMonthlySubscriptionCancelled?(): void
|
||||
trackAddApiCreditButtonClicked?(): void
|
||||
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
|
||||
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
|
||||
trackApiCreditTopupSucceeded?(): void
|
||||
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void
|
||||
|
||||
@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -391,12 +391,13 @@ const showZeroState = computed(
|
||||
)
|
||||
|
||||
function handleSubscribeWorkspace() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleUpgrade() {
|
||||
if (isFreeTierPlan.value) showPricingTable()
|
||||
else showSubscriptionDialog()
|
||||
if (isFreeTierPlan.value)
|
||||
showPricingTable({ reason: 'settings_billing_panel' })
|
||||
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleViewMoreDetails() {
|
||||
|
||||
@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
|
||||
import { useEventListener } from '@vueuse/core'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
|
||||
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
|
||||
|
||||
const { onClose, reason, initialPlanMode } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
initialPlanMode?: 'personal' | 'team'
|
||||
}>()
|
||||
|
||||
@@ -152,7 +152,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleTeamSubscribe,
|
||||
handleResubscribe
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
|
||||
// Backspace mirrors the back arrow on the confirm step, but never while an
|
||||
// editable element is focused (let it delete text there).
|
||||
|
||||
@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
|
||||
|
||||
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
|
||||
const mockHandleSuccessClose = vi.fn()
|
||||
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
|
||||
const mockPreviewData = ref<{ transition_type: string } | null>(null)
|
||||
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
|
||||
useSubscriptionCheckout: () => ({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
useSubscriptionCheckout: mockUseSubscriptionCheckout
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
@@ -91,7 +76,7 @@ const SuccessStub = {
|
||||
function renderComponent(
|
||||
props: {
|
||||
onClose?: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
} = {}
|
||||
) {
|
||||
@@ -121,6 +106,23 @@ function renderComponent(
|
||||
describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseSubscriptionCheckout.mockReturnValue({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
mockCheckoutStep.value = 'pricing'
|
||||
mockPreviewData.value = null
|
||||
})
|
||||
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('passes the reason into subscription checkout', () => {
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
'out_of_credits'
|
||||
)
|
||||
})
|
||||
|
||||
it('shows the team workspace header by default', () => {
|
||||
renderComponent()
|
||||
expect(screen.getByText('Team Workspace')).toBeInTheDocument()
|
||||
|
||||
@@ -116,7 +116,7 @@
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import PricingTableWorkspace from './PricingTableWorkspace.vue'
|
||||
@@ -130,7 +130,7 @@ const {
|
||||
isPersonal = false
|
||||
} = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
}>()
|
||||
|
||||
@@ -154,7 +154,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleResubscribe,
|
||||
handleSuccessClose
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -61,6 +61,9 @@ function onDismiss() {
|
||||
|
||||
function onUpgrade() {
|
||||
dialogStore.closeDialog({ key: 'invite-member-upsell' })
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({
|
||||
planMode: 'team',
|
||||
reason: 'invite_member_upsell'
|
||||
})
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -277,7 +277,7 @@ export function useMembersPanel() {
|
||||
}
|
||||
|
||||
function showTeamPlans() {
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed } from 'vue'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import type { Plan } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
import { findPlanSlug } from './useSubscriptionCheckout'
|
||||
@@ -75,7 +76,9 @@ const {
|
||||
mockPlans,
|
||||
mockResubscribe,
|
||||
mockToastAdd,
|
||||
mockStartOperation
|
||||
mockStartOperation,
|
||||
mockTrackBeginCheckout,
|
||||
mockUserId
|
||||
} = vi.hoisted(() => ({
|
||||
mockSubscribe: vi.fn(),
|
||||
mockPreviewSubscribe: vi.fn(),
|
||||
@@ -84,7 +87,9 @@ const {
|
||||
mockPlans: { value: [] as Plan[] },
|
||||
mockResubscribe: vi.fn(),
|
||||
mockToastAdd: vi.fn(),
|
||||
mockStartOperation: vi.fn()
|
||||
mockStartOperation: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
|
||||
useTelemetry: () => ({
|
||||
trackMonthlySubscriptionSucceeded: vi.fn(),
|
||||
trackBeginCheckout: mockTrackBeginCheckout
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', async (importOriginal) => {
|
||||
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
|
||||
describe('useSubscriptionCheckout', () => {
|
||||
let emit: ReturnType<typeof vi.fn>
|
||||
|
||||
async function setup() {
|
||||
async function setup(paymentIntentSource?: PaymentIntentSource) {
|
||||
const { useSubscriptionCheckout } =
|
||||
await import('./useSubscriptionCheckout')
|
||||
return useSubscriptionCheckout(emit as never)
|
||||
return useSubscriptionCheckout(emit as never, paymentIntentSource)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
vi.clearAllMocks()
|
||||
mockPlans.value = allPlans()
|
||||
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
|
||||
mockUserId.value = 'user-1'
|
||||
emit = vi.fn()
|
||||
})
|
||||
|
||||
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
|
||||
cancelUrl: 'https://platform.comfy.org/payment/failed'
|
||||
})
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-team-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('uses the annual plan slug for the yearly cycle', async () => {
|
||||
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Team payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('keeps team checkout_type as change when the preview request fails', async () => {
|
||||
const checkout = await setup()
|
||||
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
|
||||
await checkout.handleSubscribeTeamClick({
|
||||
stop: {
|
||||
id: 'team_1400',
|
||||
usd: 1400,
|
||||
credits: 295_400,
|
||||
discountedUsd: 1295
|
||||
},
|
||||
billingCycle: 'monthly',
|
||||
isChange: true
|
||||
})
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleTeamSubscribe()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'change',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
})
|
||||
|
||||
it('skips begin_checkout when no user id is available', async () => {
|
||||
mockUserId.value = null
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
mockUserId.value = 'user-1'
|
||||
})
|
||||
|
||||
it('fires begin_checkout carrying the payment intent source', async () => {
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'standard',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-1',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('opens payment URL when needs_payment_method', async () => {
|
||||
const checkout = await setup()
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type {
|
||||
Plan,
|
||||
PreviewSubscribeResponse,
|
||||
SubscribeResponse
|
||||
} from '@/platform/workspace/api/workspaceApi'
|
||||
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
type CheckoutStep = 'pricing' | 'preview' | 'success'
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
|
||||
interface SelectedTeamCheckout {
|
||||
stop: TeamPlanSelection
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
}
|
||||
|
||||
/**
|
||||
* Which screen the `preview` step shows. Only a change prorates: a team change
|
||||
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
|
||||
@@ -45,9 +55,12 @@ export function findPlanSlug(
|
||||
return plan?.slug ?? null
|
||||
}
|
||||
|
||||
export function useSubscriptionCheckout(emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
}) {
|
||||
export function useSubscriptionCheckout(
|
||||
emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
},
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
) {
|
||||
const { t } = useI18n()
|
||||
const toast = useToast()
|
||||
const {
|
||||
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
|
||||
const isResubscribing = ref(false)
|
||||
const previewData = ref<PreviewSubscribeResponse | null>(null)
|
||||
const selectedTierKey = ref<CheckoutTierKey | null>(null)
|
||||
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
|
||||
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
|
||||
const selectedBillingCycle = ref<BillingCycle>('yearly')
|
||||
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
|
||||
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
|
||||
const selectedTeamStop = computed(
|
||||
() => selectedTeamCheckout.value?.stop ?? null
|
||||
)
|
||||
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
|
||||
|
||||
const previewVariant = computed<PreviewVariant>(() => {
|
||||
if (selectedTeamStop.value) {
|
||||
if (selectedTeamCheckout.value) {
|
||||
return previewData.value ? 'team-change' : 'team-new'
|
||||
}
|
||||
if (previewData.value) {
|
||||
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
|
||||
billingCycle: BillingCycle
|
||||
isChange?: boolean
|
||||
}) {
|
||||
selectedTeamStop.value = payload.stop
|
||||
selectedTeamCheckout.value = {
|
||||
stop: payload.stop,
|
||||
checkoutType: payload.isChange ? 'change' : 'new'
|
||||
}
|
||||
selectedBillingCycle.value = payload.billingCycle
|
||||
selectedTierKey.value = null
|
||||
previewData.value = null
|
||||
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
|
||||
function handleBackToPricing() {
|
||||
checkoutStep.value = 'pricing'
|
||||
previewData.value = null
|
||||
selectedTeamStop.value = null
|
||||
selectedTeamCheckout.value = null
|
||||
}
|
||||
|
||||
function handleSuccessClose() {
|
||||
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleSubscription() {
|
||||
if (!selectedTierKey.value) return
|
||||
const tierKey = selectedTierKey.value
|
||||
if (!tierKey) return
|
||||
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
const checkoutType =
|
||||
previewData.value &&
|
||||
previewData.value.transition_type !== 'new_subscription'
|
||||
? 'change'
|
||||
: 'new'
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getApiPlanSlug(
|
||||
selectedTierKey.value,
|
||||
selectedBillingCycle.value
|
||||
)
|
||||
const planSlug = getApiPlanSlug(tierKey, billingCycle)
|
||||
if (!planSlug) return
|
||||
const response = await subscribe(planSlug, {
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: tierKey,
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleTeamSubscription() {
|
||||
const stop = selectedTeamStop.value
|
||||
if (!stop?.id) {
|
||||
const teamCheckout = selectedTeamCheckout.value
|
||||
if (!teamCheckout?.stop.id) {
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: t('subscription.teamPlan.name'),
|
||||
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
|
||||
return
|
||||
}
|
||||
|
||||
const { stop, checkoutType } = teamCheckout
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
|
||||
const planSlug = getTeamPlanSlug(billingCycle)
|
||||
const response = await subscribe(planSlug, {
|
||||
teamCreditStopId: stop.id,
|
||||
billingCycle: selectedBillingCycle.value,
|
||||
billingCycle,
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
|
||||
|
||||
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
|
||||
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingBalanceResponse,
|
||||
BillingStatusResponse,
|
||||
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
subscriptionDialog.show()
|
||||
subscriptionDialog.show({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
subscriptionDialog.show()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
subscriptionDialog.show(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutTier,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
interface TrackWorkspaceCheckoutStartedOptions {
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
billingOpId: string
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export function trackWorkspaceCheckoutStarted({
|
||||
tier,
|
||||
cycle,
|
||||
checkoutType,
|
||||
billingOpId,
|
||||
paymentIntentSource
|
||||
}: TrackWorkspaceCheckoutStartedOptions) {
|
||||
const { userId } = useAuthStore()
|
||||
if (!userId) return
|
||||
|
||||
useTelemetry()?.trackBeginCheckout({
|
||||
user_id: userId,
|
||||
tier,
|
||||
cycle,
|
||||
checkout_type: checkoutType,
|
||||
billing_op_id: billingOpId,
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {})
|
||||
})
|
||||
}
|
||||
286
src/services/comfyRegistryService.test.ts
Normal file
286
src/services/comfyRegistryService.test.ts
Normal file
@@ -0,0 +1,286 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
interface AxiosLikeError extends Error {
|
||||
isAxiosError: true
|
||||
response?: {
|
||||
status: number
|
||||
data?: {
|
||||
message?: string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const mockClient = vi.hoisted(() => ({
|
||||
get: vi.fn(),
|
||||
post: vi.fn()
|
||||
}))
|
||||
|
||||
const mockAxios = vi.hoisted(() => ({
|
||||
create: vi.fn(() => mockClient),
|
||||
isAxiosError: vi.fn(
|
||||
(error: unknown): error is AxiosLikeError =>
|
||||
typeof error === 'object' &&
|
||||
error !== null &&
|
||||
'isAxiosError' in error &&
|
||||
error.isAxiosError === true
|
||||
)
|
||||
}))
|
||||
|
||||
vi.mock('axios', () => ({
|
||||
default: mockAxios
|
||||
}))
|
||||
|
||||
import { useComfyRegistryService } from './comfyRegistryService'
|
||||
|
||||
function response<T>(data: T) {
|
||||
return { data }
|
||||
}
|
||||
|
||||
function axiosError(
|
||||
message: string,
|
||||
responseData?: AxiosLikeError['response']
|
||||
): AxiosLikeError {
|
||||
const error = new Error(message) as AxiosLikeError
|
||||
error.isAxiosError = true
|
||||
if (responseData) error.response = responseData
|
||||
return error
|
||||
}
|
||||
|
||||
describe('useComfyRegistryService', () => {
|
||||
beforeEach(() => {
|
||||
mockClient.get.mockReset()
|
||||
mockClient.post.mockReset()
|
||||
mockAxios.isAxiosError.mockClear()
|
||||
})
|
||||
|
||||
it('configures the registry axios client with repeated query params', () => {
|
||||
expect(mockAxios.create).toHaveBeenCalledWith({
|
||||
baseURL: 'https://api.comfy.org',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
paramsSerializer: {
|
||||
indexes: null
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('returns response data and clears loading state for successful requests', async () => {
|
||||
mockClient.get.mockResolvedValueOnce(response({ nodes: [] }))
|
||||
const service = useComfyRegistryService()
|
||||
|
||||
const result = await service.search({ search: 'manager' })
|
||||
|
||||
expect(result).toEqual({ nodes: [] })
|
||||
expect(mockClient.get).toHaveBeenCalledWith('/nodes/search', {
|
||||
params: { search: 'manager' },
|
||||
signal: undefined
|
||||
})
|
||||
expect(service.error.value).toBeNull()
|
||||
expect(service.isLoading.value).toBe(false)
|
||||
})
|
||||
|
||||
it('skips node definition requests when pack id or version is missing', async () => {
|
||||
const service = useComfyRegistryService()
|
||||
|
||||
await expect(
|
||||
service.getNodeDefs({ packId: '', version: '1.0.0' })
|
||||
).resolves.toBeNull()
|
||||
await expect(
|
||||
service.getNodeDefs({ packId: 'pack', version: '' })
|
||||
).resolves.toBeNull()
|
||||
expect(mockClient.get).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('passes query params and abort signals through node definition requests', async () => {
|
||||
const signal = new AbortController().signal
|
||||
mockClient.get.mockResolvedValueOnce(response([{ name: 'KSampler' }]))
|
||||
const service = useComfyRegistryService()
|
||||
|
||||
const result = await service.getNodeDefs(
|
||||
{ packId: 'pack', version: '1.0.0', page: 2 },
|
||||
signal
|
||||
)
|
||||
|
||||
expect(result).toEqual([{ name: 'KSampler' }])
|
||||
expect(mockClient.get).toHaveBeenCalledWith(
|
||||
'/nodes/pack/versions/1.0.0/comfy-nodes',
|
||||
{
|
||||
params: { page: 2 },
|
||||
signal
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('routes publisher, pack, and review methods to their registry endpoints', async () => {
|
||||
mockClient.get
|
||||
.mockResolvedValueOnce(response({ id: 'publisher' }))
|
||||
.mockResolvedValueOnce(response([{ id: 'pack' }]))
|
||||
.mockResolvedValueOnce(response([{ version: '1.0.0' }]))
|
||||
.mockResolvedValueOnce(response({ id: 'version' }))
|
||||
.mockResolvedValueOnce(response({ id: 'pack' }))
|
||||
.mockResolvedValueOnce(response({ id: 'pack' }))
|
||||
.mockResolvedValueOnce(response({ id: 'pack' }))
|
||||
mockClient.post
|
||||
.mockResolvedValueOnce(response({ id: 'reviewed' }))
|
||||
.mockResolvedValueOnce(response({ node_versions: [] }))
|
||||
const service = useComfyRegistryService()
|
||||
const signal = new AbortController().signal
|
||||
|
||||
await expect(
|
||||
service.getPublisherById('publisher', signal)
|
||||
).resolves.toEqual({ id: 'publisher' })
|
||||
await expect(
|
||||
service.listPacksForPublisher('publisher', true, signal)
|
||||
).resolves.toEqual([{ id: 'pack' }])
|
||||
await expect(
|
||||
service.getPackVersions(
|
||||
'pack',
|
||||
{ statuses: ['NodeVersionStatusActive'] },
|
||||
signal
|
||||
)
|
||||
).resolves.toEqual([{ version: '1.0.0' }])
|
||||
await expect(
|
||||
service.getPackByVersion('pack', 'version', signal)
|
||||
).resolves.toEqual({ id: 'version' })
|
||||
await expect(service.getPackById('pack', signal)).resolves.toEqual({
|
||||
id: 'pack'
|
||||
})
|
||||
await expect(
|
||||
service.inferPackFromNodeName('KSampler', signal)
|
||||
).resolves.toEqual({ id: 'pack' })
|
||||
await expect(service.listAllPacks({ page: 1 }, signal)).resolves.toEqual({
|
||||
id: 'pack'
|
||||
})
|
||||
await expect(service.postPackReview('pack', 5, signal)).resolves.toEqual({
|
||||
id: 'reviewed'
|
||||
})
|
||||
await expect(
|
||||
service.getBulkNodeVersions(
|
||||
[{ node_id: 'pack', version: '1.0.0' }],
|
||||
signal
|
||||
)
|
||||
).resolves.toEqual({ node_versions: [] })
|
||||
|
||||
expect(mockClient.get).toHaveBeenNthCalledWith(1, '/publishers/publisher', {
|
||||
signal
|
||||
})
|
||||
expect(mockClient.get).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'/publishers/publisher/nodes',
|
||||
{
|
||||
params: { include_banned: true },
|
||||
signal
|
||||
}
|
||||
)
|
||||
expect(mockClient.get).toHaveBeenNthCalledWith(3, '/nodes/pack/versions', {
|
||||
params: { statuses: ['NodeVersionStatusActive'] },
|
||||
signal
|
||||
})
|
||||
expect(mockClient.get).toHaveBeenNthCalledWith(
|
||||
4,
|
||||
'/nodes/pack/versions/version',
|
||||
{ signal }
|
||||
)
|
||||
expect(mockClient.get).toHaveBeenNthCalledWith(5, '/nodes/pack', {
|
||||
signal
|
||||
})
|
||||
expect(mockClient.get).toHaveBeenNthCalledWith(
|
||||
6,
|
||||
'/comfy-nodes/KSampler/node',
|
||||
{ signal }
|
||||
)
|
||||
expect(mockClient.get).toHaveBeenNthCalledWith(7, '/nodes', {
|
||||
params: { page: 1 },
|
||||
signal
|
||||
})
|
||||
expect(mockClient.post).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'/nodes/pack/reviews',
|
||||
null,
|
||||
{ params: { star: 5 }, signal }
|
||||
)
|
||||
expect(mockClient.post).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'/bulk/nodes/versions',
|
||||
{ node_versions: [{ node_id: 'pack', version: '1.0.0' }] },
|
||||
{ signal }
|
||||
)
|
||||
})
|
||||
|
||||
it('omits include_banned when listing publisher packs without banned packs', async () => {
|
||||
mockClient.get.mockResolvedValueOnce(response([]))
|
||||
const service = useComfyRegistryService()
|
||||
|
||||
await service.listPacksForPublisher('publisher', false)
|
||||
|
||||
expect(mockClient.get).toHaveBeenCalledWith('/publishers/publisher/nodes', {
|
||||
params: undefined,
|
||||
signal: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it.for([
|
||||
{ status: 400, expected: 'Bad request: Invalid input' },
|
||||
{ status: 401, expected: 'Unauthorized: Authentication required' },
|
||||
{ status: 403, expected: 'Forbidden: Access denied' },
|
||||
{ status: 404, expected: 'Not found: Resource not found' },
|
||||
{ status: 409, expected: 'Conflict: Resource conflict' },
|
||||
{ status: 500, expected: 'Server error: Internal server error' },
|
||||
{ status: 418, expected: 'Failed to perform search: teapot' }
|
||||
])(
|
||||
'normalizes axios response status $status',
|
||||
async ({ status, expected }) => {
|
||||
mockClient.get.mockRejectedValueOnce(
|
||||
axiosError('Request failed', {
|
||||
status,
|
||||
data: status === 418 ? { message: 'teapot' } : {}
|
||||
})
|
||||
)
|
||||
const service = useComfyRegistryService()
|
||||
|
||||
await expect(service.search()).resolves.toBeNull()
|
||||
|
||||
expect(service.error.value).toBe(expected)
|
||||
expect(service.isLoading.value).toBe(false)
|
||||
}
|
||||
)
|
||||
|
||||
it('uses route-specific errors before generic status messages', async () => {
|
||||
mockClient.get.mockRejectedValueOnce(
|
||||
axiosError('Request failed', {
|
||||
status: 404,
|
||||
data: { message: 'ignored' }
|
||||
})
|
||||
)
|
||||
const service = useComfyRegistryService()
|
||||
|
||||
await expect(service.getPackById('missing')).resolves.toBeNull()
|
||||
|
||||
expect(service.error.value).toBe(
|
||||
'Pack not found: The pack with ID missing does not exist'
|
||||
)
|
||||
})
|
||||
|
||||
it('normalizes network, thrown Error, unknown, and abort failures', async () => {
|
||||
const service = useComfyRegistryService()
|
||||
|
||||
mockClient.get.mockRejectedValueOnce(axiosError('Network down'))
|
||||
await expect(service.search()).resolves.toBeNull()
|
||||
expect(service.error.value).toBe('Failed to perform search: Network down')
|
||||
|
||||
mockClient.get.mockRejectedValueOnce(new Error('boom'))
|
||||
await expect(service.search()).resolves.toBeNull()
|
||||
expect(service.error.value).toBe('Failed to perform search: boom')
|
||||
|
||||
mockClient.get.mockRejectedValueOnce('bad')
|
||||
await expect(service.search()).resolves.toBeNull()
|
||||
expect(service.error.value).toBe(
|
||||
'Failed to perform search: Unknown error occurred'
|
||||
)
|
||||
|
||||
mockClient.get.mockRejectedValueOnce(new DOMException('', 'AbortError'))
|
||||
await expect(service.search()).resolves.toBeNull()
|
||||
expect(service.error.value).toBeNull()
|
||||
})
|
||||
})
|
||||
@@ -18,7 +18,7 @@ import type {
|
||||
} from '@/stores/dialogStore'
|
||||
|
||||
import type { ComponentAttrs } from 'vue-component-type-helpers'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { WorkspaceRole } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
// Lazy loaders for dialogs - components are loaded on first use
|
||||
@@ -442,9 +442,9 @@ export const useDialogService = () => {
|
||||
})
|
||||
}
|
||||
|
||||
async function showSubscriptionRequiredDialog(options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) {
|
||||
async function showSubscriptionRequiredDialog(
|
||||
options?: SubscriptionDialogOptions
|
||||
) {
|
||||
if (!isCloud || !window.__CONFIG__?.subscription_required) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -234,6 +234,54 @@ describe('useRegistrySearchGateway', () => {
|
||||
const gateway = useRegistrySearchGateway()
|
||||
expect(gateway).toBeDefined()
|
||||
})
|
||||
|
||||
it('waits for the circuit breaker timeout before retrying a failed provider', async () => {
|
||||
vi.useFakeTimers()
|
||||
vi.setSystemTime(new Date('2024-01-01T00:00:00Z'))
|
||||
|
||||
vi.mocked(useAlgoliaSearchProvider).mockImplementation(() => {
|
||||
throw new Error('Algolia init failed')
|
||||
})
|
||||
|
||||
const registryResult = {
|
||||
nodePacks: [{ id: 'registry-1', name: 'Registry Pack' }],
|
||||
querySuggestions: []
|
||||
}
|
||||
|
||||
const mockRegistryProvider = {
|
||||
searchPacks: vi.fn().mockRejectedValue(new Error('Registry failed')),
|
||||
clearSearchCache: vi.fn(),
|
||||
getSortValue: vi.fn(),
|
||||
getSortableFields: vi.fn().mockReturnValue([])
|
||||
}
|
||||
|
||||
vi.mocked(useComfyRegistrySearchProvider).mockReturnValue(
|
||||
mockRegistryProvider
|
||||
)
|
||||
|
||||
const gateway = useRegistrySearchGateway()
|
||||
|
||||
for (let attempt = 0; attempt < 3; attempt++) {
|
||||
await expect(
|
||||
gateway.searchPacks('test', { pageSize: 10, pageNumber: 0 })
|
||||
).rejects.toThrow('All search providers failed')
|
||||
}
|
||||
|
||||
expect(mockRegistryProvider.searchPacks).toHaveBeenCalledTimes(3)
|
||||
|
||||
await expect(
|
||||
gateway.searchPacks('test', { pageSize: 10, pageNumber: 0 })
|
||||
).rejects.toThrow('All search providers failed')
|
||||
expect(mockRegistryProvider.searchPacks).toHaveBeenCalledTimes(3)
|
||||
|
||||
vi.setSystemTime(new Date('2024-01-01T00:01:01Z'))
|
||||
mockRegistryProvider.searchPacks.mockResolvedValueOnce(registryResult)
|
||||
|
||||
await expect(
|
||||
gateway.searchPacks('test', { pageSize: 10, pageNumber: 0 })
|
||||
).resolves.toBe(registryResult)
|
||||
expect(mockRegistryProvider.searchPacks).toHaveBeenCalledTimes(4)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Cache management', () => {
|
||||
|
||||
@@ -126,6 +126,19 @@ describe('useAssetDownloadStore', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('keeps the first placeholder when the same task is tracked twice', () => {
|
||||
const store = useAssetDownloadStore()
|
||||
|
||||
store.trackDownload('task-123', 'checkpoints', 'first.safetensors')
|
||||
store.trackDownload('task-123', 'loras', 'second.safetensors')
|
||||
|
||||
expect(store.downloadList).toHaveLength(1)
|
||||
expect(store.downloadList[0]).toMatchObject({
|
||||
modelType: 'checkpoints',
|
||||
assetName: 'first.safetensors'
|
||||
})
|
||||
})
|
||||
|
||||
it('handles out-of-order messages where completed arrives before progress', () => {
|
||||
const store = useAssetDownloadStore()
|
||||
|
||||
@@ -179,6 +192,19 @@ describe('useAssetDownloadStore', () => {
|
||||
expect(store.finishedDownloads[0].status).toBe('completed')
|
||||
})
|
||||
|
||||
it('skips polling when active downloads have fresh progress', async () => {
|
||||
const store = useAssetDownloadStore()
|
||||
|
||||
dispatch(createDownloadMessage({ status: 'running' }))
|
||||
await vi.advanceTimersByTimeAsync(9_999)
|
||||
dispatch(createDownloadMessage({ status: 'running', progress: 75 }))
|
||||
await vi.advanceTimersByTimeAsync(1)
|
||||
|
||||
expect(taskService.getTask).not.toHaveBeenCalled()
|
||||
expect(store.activeDownloads).toHaveLength(1)
|
||||
expect(store.activeDownloads[0].progress).toBe(75)
|
||||
})
|
||||
|
||||
it('polls and marks failed downloads', async () => {
|
||||
const store = useAssetDownloadStore()
|
||||
|
||||
@@ -311,5 +337,22 @@ describe('useAssetDownloadStore', () => {
|
||||
expect(store.sessionDownloadCount).toBe(0)
|
||||
expect(store.isDownloadedThisSession('asset-456')).toBe(false)
|
||||
})
|
||||
|
||||
it('does not acknowledge unrelated completed downloads', () => {
|
||||
const store = useAssetDownloadStore()
|
||||
|
||||
dispatch(
|
||||
createDownloadMessage({
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
asset_id: 'asset-456'
|
||||
})
|
||||
)
|
||||
|
||||
store.acknowledgeAsset('other-asset')
|
||||
|
||||
expect(store.sessionDownloadCount).toBe(1)
|
||||
expect(store.isDownloadedThisSession('asset-456')).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick, watch } from 'vue'
|
||||
@@ -11,6 +12,7 @@ import type {
|
||||
} from '@/platform/assets/schemas/assetSchema'
|
||||
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
|
||||
|
||||
// Mock the api module
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
@@ -96,6 +98,10 @@ const mockOutputOverrides = vi.hoisted(() => ({
|
||||
value: null as MockOutput[] | null
|
||||
}))
|
||||
|
||||
const mockAssetMapperOptions = vi.hoisted(() => ({
|
||||
omitCreatedAtForIds: new Set<string>()
|
||||
}))
|
||||
|
||||
// Mock TaskItemImpl
|
||||
const PREVIEWABLE_MEDIA_TYPES = new Set(['images', 'video', 'audio'])
|
||||
|
||||
@@ -169,11 +175,14 @@ vi.mock('@/platform/assets/composables/media/assetMappers', () => ({
|
||||
})),
|
||||
mapTaskOutputToAssetItem: vi.fn((task, output) => {
|
||||
const index = parseInt(task.jobId.split('_')[1]) || 0
|
||||
const createdAt = new Date(Date.now() - index * 1000).toISOString()
|
||||
return {
|
||||
id: task.jobId,
|
||||
name: output.filename,
|
||||
size: 0,
|
||||
created_at: new Date(Date.now() - index * 1000).toISOString(),
|
||||
...(!mockAssetMapperOptions.omitCreatedAtForIds.has(task.jobId) && {
|
||||
created_at: createdAt
|
||||
}),
|
||||
tags: ['output'],
|
||||
preview_url: output.url,
|
||||
user_metadata: {}
|
||||
@@ -205,6 +214,7 @@ describe('assetsStore - Refactored (Option A)', () => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
store = useAssetsStore()
|
||||
vi.clearAllMocks()
|
||||
mockAssetMapperOptions.omitCreatedAtForIds.clear()
|
||||
})
|
||||
|
||||
describe('Initial Load', () => {
|
||||
@@ -272,6 +282,17 @@ describe('assetsStore - Refactored (Option A)', () => {
|
||||
'prompt_2'
|
||||
])
|
||||
})
|
||||
|
||||
it('should skip unfinished jobs and completed jobs without previews', async () => {
|
||||
vi.mocked(api.getHistory).mockResolvedValue([
|
||||
{ ...createMockJobItem(0), status: 'in_progress' },
|
||||
{ ...createMockJobItem(1), preview_output: undefined }
|
||||
])
|
||||
|
||||
await store.updateHistory()
|
||||
|
||||
expect(store.historyAssets).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Pagination', () => {
|
||||
@@ -328,6 +349,46 @@ describe('assetsStore - Refactored (Option A)', () => {
|
||||
expect(uniqueAssetIds.size).toBe(store.historyAssets.length)
|
||||
})
|
||||
|
||||
it('should insert newer paginated items in sorted order', async () => {
|
||||
vi.mocked(api.getHistory).mockResolvedValueOnce(
|
||||
Array.from({ length: 200 }, (_, i) => createMockJobItem(i))
|
||||
)
|
||||
await store.updateHistory()
|
||||
|
||||
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(-1)])
|
||||
await store.loadMoreHistory()
|
||||
|
||||
expect(store.historyAssets[0].id).toBe('prompt_-1')
|
||||
})
|
||||
|
||||
it('sorts paginated items when the incoming asset has no timestamp', async () => {
|
||||
vi.mocked(api.getHistory).mockResolvedValueOnce(
|
||||
Array.from({ length: 200 }, (_, i) => createMockJobItem(i))
|
||||
)
|
||||
await store.updateHistory()
|
||||
mockAssetMapperOptions.omitCreatedAtForIds.add('prompt_200')
|
||||
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(200)])
|
||||
|
||||
await store.loadMoreHistory()
|
||||
|
||||
expect(store.historyAssets.at(-1)?.id).toBe('prompt_200')
|
||||
})
|
||||
|
||||
it('sorts paginated items when an existing asset has no timestamp', async () => {
|
||||
for (let i = 0; i < 200; i++) {
|
||||
mockAssetMapperOptions.omitCreatedAtForIds.add(`prompt_${i}`)
|
||||
}
|
||||
vi.mocked(api.getHistory).mockResolvedValueOnce(
|
||||
Array.from({ length: 200 }, (_, i) => createMockJobItem(i))
|
||||
)
|
||||
await store.updateHistory()
|
||||
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(-1)])
|
||||
|
||||
await store.loadMoreHistory()
|
||||
|
||||
expect(store.historyAssets[0].id).toBe('prompt_-1')
|
||||
})
|
||||
|
||||
it('should stop loading when no more items', async () => {
|
||||
// First batch - less than BATCH_SIZE
|
||||
const firstBatch = Array.from({ length: 50 }, (_, i) =>
|
||||
@@ -494,6 +555,29 @@ describe('assetsStore - Refactored (Option A)', () => {
|
||||
expect(store.historyLoading).toBe(false)
|
||||
expect(store.historyError).toBe(error)
|
||||
})
|
||||
|
||||
it('should preserve existing history when refresh fails', async () => {
|
||||
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(0)])
|
||||
await store.updateHistory()
|
||||
|
||||
const error = new Error('API error')
|
||||
vi.mocked(api.getHistory).mockRejectedValueOnce(error)
|
||||
|
||||
await store.updateHistory()
|
||||
|
||||
expect(store.historyAssets).toHaveLength(1)
|
||||
expect(store.historyError).toBe(error)
|
||||
})
|
||||
|
||||
it('should keep empty history when loadMore fails before any load', async () => {
|
||||
const error = new Error('API error')
|
||||
vi.mocked(api.getHistory).mockRejectedValueOnce(error)
|
||||
|
||||
await store.loadMoreHistory()
|
||||
|
||||
expect(store.historyAssets).toEqual([])
|
||||
expect(store.historyError).toBe(error)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Memory Management', () => {
|
||||
@@ -924,6 +1008,43 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
vi.mocked(assetService.getAssetsForNodeType)
|
||||
).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('ignores a model response after the category is invalidated', async () => {
|
||||
const store = useAssetsStore()
|
||||
let resolveFetch!: (assets: AssetItem[]) => void
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockReturnValueOnce(
|
||||
new Promise((resolve) => {
|
||||
resolveFetch = resolve
|
||||
})
|
||||
)
|
||||
|
||||
const request = store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
store.invalidateCategory('checkpoints')
|
||||
resolveFetch([createMockAsset('stale-response')])
|
||||
await request
|
||||
|
||||
expect(store.getAssets('CheckpointLoaderSimple')).toEqual([])
|
||||
})
|
||||
|
||||
it('ignores a model rejection after the category is invalidated', async () => {
|
||||
const store = useAssetsStore()
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
let rejectFetch!: (error: Error) => void
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockReturnValueOnce(
|
||||
new Promise((_resolve, reject) => {
|
||||
rejectFetch = reject
|
||||
})
|
||||
)
|
||||
|
||||
const request = store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
store.invalidateCategory('checkpoints')
|
||||
rejectFetch(new Error('stale rejection'))
|
||||
await request
|
||||
|
||||
expect(store.getError('CheckpointLoaderSimple')).toBeUndefined()
|
||||
expect(consoleSpy).not.toHaveBeenCalled()
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('shallowReactive state reactivity', () => {
|
||||
@@ -966,6 +1087,10 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
it('should return empty array for unknown node types', () => {
|
||||
const store = useAssetsStore()
|
||||
expect(store.getAssets('UnknownNodeType')).toEqual([])
|
||||
expect(store.isModelLoading('UnknownNodeType')).toBe(false)
|
||||
expect(store.getError('UnknownNodeType')).toBeUndefined()
|
||||
expect(store.hasMore('UnknownNodeType')).toBe(false)
|
||||
expect(store.hasAssetKey('UnknownNodeType')).toBe(false)
|
||||
})
|
||||
|
||||
it('should not fetch for unknown node types', async () => {
|
||||
@@ -975,6 +1100,63 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
vi.mocked(assetService.getAssetsForNodeType)
|
||||
).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should refresh an already loaded category', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'CheckpointLoaderSimple'
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
createMockAsset('first')
|
||||
])
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
createMockAsset('second')
|
||||
])
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
|
||||
expect(store.getAssets(nodeType).map((asset) => asset.id)).toEqual([
|
||||
'second'
|
||||
])
|
||||
})
|
||||
|
||||
it('reports hasMore for a loaded category', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'CheckpointLoaderSimple'
|
||||
|
||||
expect(store.hasMore(nodeType)).toBe(false)
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
createMockAsset('only-page')
|
||||
])
|
||||
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
|
||||
expect(store.hasMore(nodeType)).toBe(false)
|
||||
})
|
||||
|
||||
it('should record model loading errors', async () => {
|
||||
const store = useAssetsStore()
|
||||
const error = new Error('model fetch failed')
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockRejectedValueOnce(error)
|
||||
|
||||
await store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
|
||||
expect(store.getError('CheckpointLoaderSimple')).toBe(error)
|
||||
expect(store.isModelLoading('CheckpointLoaderSimple')).toBe(false)
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should wrap non-error model loading failures', async () => {
|
||||
const store = useAssetsStore()
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockRejectedValueOnce('boom')
|
||||
|
||||
await store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
|
||||
expect(store.getError('CheckpointLoaderSimple')?.message).toBe('boom')
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('invalidateCategory', () => {
|
||||
@@ -1129,7 +1311,140 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('completed download refresh', () => {
|
||||
it('refreshes provider and tag caches for the completed model type', async () => {
|
||||
const store = useAssetsStore()
|
||||
const downloadStore = useAssetDownloadStore()
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([])
|
||||
vi.mocked(assetService.getAssetsByTag).mockResolvedValue([])
|
||||
|
||||
downloadStore.lastCompletedDownload = {
|
||||
taskId: 'task-1',
|
||||
modelType: 'checkpoints',
|
||||
timestamp: 1
|
||||
}
|
||||
|
||||
await vi.waitFor(() =>
|
||||
expect(assetService.getAssetsByTag).toHaveBeenCalledWith(
|
||||
'models',
|
||||
true,
|
||||
expect.objectContaining({ limit: 500, offset: 0 })
|
||||
)
|
||||
)
|
||||
|
||||
expect(assetService.getAssetsForNodeType).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
expect.objectContaining({ limit: 500, offset: 0 })
|
||||
)
|
||||
expect(assetService.getAssetsForNodeType).toHaveBeenCalledTimes(1)
|
||||
expect(assetService.getAssetsByTag).toHaveBeenCalledWith(
|
||||
'checkpoints',
|
||||
true,
|
||||
expect.objectContaining({ limit: 500, offset: 0 })
|
||||
)
|
||||
expect(store.hasCategory('tag:models')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('updateAssetMetadata optimistic cache', () => {
|
||||
it('still writes metadata when a cache key is unresolved', async () => {
|
||||
const store = useAssetsStore()
|
||||
const original = {
|
||||
...createMockAsset('opt-unknown'),
|
||||
user_metadata: { note: 'before' } as Record<string, unknown>
|
||||
}
|
||||
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
|
||||
...original,
|
||||
user_metadata: { note: 'after' }
|
||||
})
|
||||
|
||||
await store.updateAssetMetadata(
|
||||
original,
|
||||
{ note: 'after' },
|
||||
'UnknownNodeType'
|
||||
)
|
||||
|
||||
expect(vi.mocked(assetService.updateAsset)).toHaveBeenCalledWith(
|
||||
'opt-unknown',
|
||||
{ user_metadata: { note: 'after' } }
|
||||
)
|
||||
})
|
||||
|
||||
it('still updates the server when the asset is not cached', async () => {
|
||||
const store = useAssetsStore()
|
||||
const original = {
|
||||
...createMockAsset('opt-missing'),
|
||||
user_metadata: { note: 'before' } as Record<string, unknown>
|
||||
}
|
||||
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
|
||||
...original,
|
||||
user_metadata: { note: 'server' }
|
||||
})
|
||||
|
||||
await store.updateAssetMetadata(original, { note: 'after' })
|
||||
|
||||
expect(vi.mocked(assetService.updateAsset)).toHaveBeenCalledWith(
|
||||
'opt-missing',
|
||||
{ user_metadata: { note: 'after' } }
|
||||
)
|
||||
})
|
||||
|
||||
it('still updates the server when a resolved cache key has not loaded yet', async () => {
|
||||
const store = useAssetsStore()
|
||||
const original = {
|
||||
...createMockAsset('opt-unloaded'),
|
||||
user_metadata: { note: 'before' } as Record<string, unknown>
|
||||
}
|
||||
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
|
||||
...original,
|
||||
user_metadata: { note: 'server' }
|
||||
})
|
||||
|
||||
await store.updateAssetMetadata(
|
||||
original,
|
||||
{ note: 'after' },
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
|
||||
expect(vi.mocked(assetService.updateAsset)).toHaveBeenCalledWith(
|
||||
'opt-unloaded',
|
||||
{ user_metadata: { note: 'after' } }
|
||||
)
|
||||
})
|
||||
|
||||
it('leaves unrelated cached assets alone during optimistic metadata update', async () => {
|
||||
const store = useAssetsStore()
|
||||
const cached = {
|
||||
...createMockAsset('opt-cached'),
|
||||
user_metadata: { note: 'cached' } as Record<string, unknown>
|
||||
}
|
||||
const missing = {
|
||||
...createMockAsset('opt-missing-from-cache'),
|
||||
user_metadata: { note: 'before' } as Record<string, unknown>
|
||||
}
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
cached
|
||||
])
|
||||
await store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
|
||||
...missing,
|
||||
user_metadata: { note: 'server' }
|
||||
})
|
||||
|
||||
await store.updateAssetMetadata(
|
||||
missing,
|
||||
{ note: 'after' },
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
|
||||
expect(
|
||||
store.getAssets('CheckpointLoaderSimple')[0].user_metadata
|
||||
).toEqual({
|
||||
note: 'cached'
|
||||
})
|
||||
})
|
||||
|
||||
it('reflects the server response in the cache after a successful update', async () => {
|
||||
const store = useAssetsStore()
|
||||
const original = {
|
||||
@@ -1237,6 +1552,31 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
'featured'
|
||||
])
|
||||
})
|
||||
|
||||
it('calls only the remove endpoint when there are no tags to add', async () => {
|
||||
const store = useAssetsStore()
|
||||
const asset = createMockAsset('tags-remove-only', ['models', 'archived'])
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
asset
|
||||
])
|
||||
await store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
|
||||
vi.mocked(assetService.removeAssetTags).mockResolvedValueOnce({
|
||||
total_tags: ['models']
|
||||
})
|
||||
|
||||
await store.updateAssetTags(asset, ['models'], 'CheckpointLoaderSimple')
|
||||
|
||||
expect(vi.mocked(assetService.removeAssetTags)).toHaveBeenCalledWith(
|
||||
'tags-remove-only',
|
||||
['archived']
|
||||
)
|
||||
expect(vi.mocked(assetService.addAssetTags)).not.toHaveBeenCalled()
|
||||
expect(store.getAssets('CheckpointLoaderSimple')[0].tags).toEqual([
|
||||
'models'
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('updateAssetTags partial-failure compensation', () => {
|
||||
@@ -1351,6 +1691,36 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
expect(store.hasCategory('tag:models')).toBe(false)
|
||||
})
|
||||
|
||||
it('keeps unrelated tag caches when compensation fails with a cache key', async () => {
|
||||
const store = useAssetsStore()
|
||||
const asset = createMockAsset('tags-target-fail', ['models', 'loras'])
|
||||
const otherAsset = createMockAsset('tags-other', ['models'])
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
asset
|
||||
])
|
||||
await store.updateModelsForNodeType('LoraLoader')
|
||||
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([otherAsset])
|
||||
await store.updateModelsForTag('models')
|
||||
|
||||
vi.mocked(assetService.removeAssetTags).mockResolvedValueOnce({
|
||||
removed: ['loras'],
|
||||
total_tags: ['models']
|
||||
})
|
||||
vi.mocked(assetService.addAssetTags)
|
||||
.mockRejectedValueOnce(new Error('500 add failed'))
|
||||
.mockRejectedValueOnce(new Error('503 compensation failed'))
|
||||
|
||||
await store.updateAssetTags(
|
||||
asset,
|
||||
['models', 'checkpoints'],
|
||||
'LoraLoader'
|
||||
)
|
||||
|
||||
expect(store.hasCategory('loras')).toBe(false)
|
||||
expect(store.hasCategory('tag:models')).toBe(true)
|
||||
})
|
||||
|
||||
it('does not attempt compensation when only the add was attempted', async () => {
|
||||
const store = useAssetsStore()
|
||||
const asset = createMockAsset('tags-add-only-fail', ['models'])
|
||||
@@ -1483,9 +1853,78 @@ describe('assetsStore - Deletion State and Input Mapping', () => {
|
||||
const store = useAssetsStore()
|
||||
expect(store.getInputName('unknown.png')).toBe('unknown.png')
|
||||
})
|
||||
|
||||
it('ignores input assets without hashes', async () => {
|
||||
mockIsCloud.value = true
|
||||
try {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
const store = useAssetsStore()
|
||||
|
||||
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([
|
||||
{
|
||||
id: 'input-1',
|
||||
name: 'plain.png',
|
||||
tags: ['input']
|
||||
}
|
||||
])
|
||||
await store.updateInputs()
|
||||
|
||||
expect(store.getInputName('plain.png')).toBe('plain.png')
|
||||
} finally {
|
||||
mockIsCloud.value = false
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('updateInputs cloud routing', () => {
|
||||
it('reads input files from the internal API when isCloud is false', async () => {
|
||||
const fetchMock = vi.fn().mockResolvedValue(
|
||||
fromAny<Response, unknown>({
|
||||
ok: true,
|
||||
json: async () => ['input-a.png', 'input-b.png']
|
||||
})
|
||||
)
|
||||
vi.stubGlobal('fetch', fetchMock)
|
||||
try {
|
||||
const store = useAssetsStore()
|
||||
|
||||
await store.updateInputs()
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledWith(
|
||||
'http://localhost:3000/files/input',
|
||||
{ headers: { 'Comfy-User': 'test-user' } }
|
||||
)
|
||||
expect(store.inputAssets.map((asset) => asset.name)).toEqual([
|
||||
'input-a.png',
|
||||
'input-b.png'
|
||||
])
|
||||
} finally {
|
||||
vi.unstubAllGlobals()
|
||||
}
|
||||
})
|
||||
|
||||
it('records internal input API failures', async () => {
|
||||
const fetchMock = vi.fn().mockResolvedValue(
|
||||
fromAny<Response, unknown>({
|
||||
ok: false
|
||||
})
|
||||
)
|
||||
vi.stubGlobal('fetch', fetchMock)
|
||||
try {
|
||||
const consoleSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
const store = useAssetsStore()
|
||||
|
||||
await store.updateInputs()
|
||||
|
||||
expect(store.inputError).toBeInstanceOf(Error)
|
||||
consoleSpy.mockRestore()
|
||||
} finally {
|
||||
vi.unstubAllGlobals()
|
||||
}
|
||||
})
|
||||
|
||||
it('reads from assetService.getAssetsByTag with limit 100 when isCloud is true', async () => {
|
||||
mockIsCloud.value = true
|
||||
try {
|
||||
@@ -1586,6 +2025,18 @@ describe('assetsStore - Flat Output Assets (cloud-only)', () => {
|
||||
expect(store.flatOutputHasMore).toBe(false)
|
||||
})
|
||||
|
||||
it('does not load more flat outputs when there are no more pages', async () => {
|
||||
vi.mocked(assetService.getAssetsPageByTag).mockResolvedValueOnce(
|
||||
makePage([makeAsset('a1', 'one.png')])
|
||||
)
|
||||
|
||||
const store = useAssetsStore()
|
||||
await store.updateFlatOutputs()
|
||||
await store.loadMoreFlatOutputs()
|
||||
|
||||
expect(assetService.getAssetsPageByTag).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('threads the minted cursor into after on loadMore and omits offset', async () => {
|
||||
vi.mocked(assetService.getAssetsPageByTag)
|
||||
.mockResolvedValueOnce(
|
||||
@@ -1800,4 +2251,26 @@ describe('assetsStore - Flat Output Assets (cloud-only)', () => {
|
||||
|
||||
expect(store.flatOutputAssets.map((x) => x.id)).toEqual(['shared-1'])
|
||||
})
|
||||
|
||||
it('ignores concurrent load more calls while one is active', async () => {
|
||||
vi.mocked(assetService.getAssetsPageByTag).mockResolvedValueOnce(
|
||||
makePage([makeAsset('a1', 'f1.png')], { hasMore: true })
|
||||
)
|
||||
const store = useAssetsStore()
|
||||
await store.updateFlatOutputs()
|
||||
|
||||
let resolvePage!: (page: AssetResponse) => void
|
||||
vi.mocked(assetService.getAssetsPageByTag).mockReturnValueOnce(
|
||||
new Promise<AssetResponse>((resolve) => {
|
||||
resolvePage = resolve
|
||||
})
|
||||
)
|
||||
|
||||
const first = store.loadMoreFlatOutputs()
|
||||
const second = store.loadMoreFlatOutputs()
|
||||
resolvePage(makePage([makeAsset('a2', 'f2.png')]))
|
||||
await Promise.all([first, second])
|
||||
|
||||
expect(assetService.getAssetsPageByTag).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
@@ -177,9 +178,10 @@ describe('useComfyRegistryStore', () => {
|
||||
|
||||
it('should return null when fetching a pack with null ID', async () => {
|
||||
const store = useComfyRegistryStore()
|
||||
vi.spyOn(store.getPackById, 'call').mockResolvedValueOnce(null)
|
||||
|
||||
const result = await store.getPackById.call(null!)
|
||||
const result = await store.getPackById.call(
|
||||
fromAny<Parameters<typeof store.getPackById.call>[0], unknown>(null)
|
||||
)
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(mockRegistryService.getPackById).not.toHaveBeenCalled()
|
||||
@@ -206,6 +208,56 @@ describe('useComfyRegistryStore', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('should reuse cached packs by ID', async () => {
|
||||
const store = useComfyRegistryStore()
|
||||
|
||||
await store.getPacksByIds.call(['test-pack-id'])
|
||||
const result = await store.getPacksByIds.call(['test-pack-id'])
|
||||
|
||||
expect(result).toEqual([mockNodePack])
|
||||
expect(mockRegistryService.listAllPacks).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should ignore missing packs by ID', async () => {
|
||||
mockRegistryService.listAllPacks.mockResolvedValueOnce({
|
||||
nodes: [fromAny<components['schemas']['Node'], unknown>({ name: 'bad' })],
|
||||
total: 1,
|
||||
page: 1,
|
||||
limit: 10
|
||||
})
|
||||
const store = useComfyRegistryStore()
|
||||
|
||||
const result = await store.getPacksByIds.call(['unknown-pack-id'])
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle empty pack lookup responses', async () => {
|
||||
mockRegistryService.listAllPacks.mockResolvedValueOnce(null)
|
||||
const store = useComfyRegistryStore()
|
||||
|
||||
const result = await store.getPacksByIds.call(['unknown-pack-id'])
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
it('should filter undefined pack IDs before lookup', async () => {
|
||||
const store = useComfyRegistryStore()
|
||||
|
||||
const result = await store.getPacksByIds.call(
|
||||
fromAny<components['schemas']['Node']['id'][], unknown>([
|
||||
'test-pack-id',
|
||||
undefined
|
||||
])
|
||||
)
|
||||
|
||||
expect(result).toEqual([mockNodePack])
|
||||
expect(mockRegistryService.listAllPacks).toHaveBeenCalledWith(
|
||||
{ node_id: ['test-pack-id'] },
|
||||
expect.any(Object)
|
||||
)
|
||||
})
|
||||
|
||||
describe('inferPackFromNodeName', () => {
|
||||
it('should fetch a pack by comfy node name', async () => {
|
||||
const store = useComfyRegistryStore()
|
||||
|
||||
@@ -137,6 +137,88 @@ describe('useModelStore', () => {
|
||||
expect(model.resolution).toBe('')
|
||||
})
|
||||
|
||||
it('keeps the default model metadata when the server returns null', async () => {
|
||||
enableMocks()
|
||||
vi.mocked(api.viewMetadata).mockResolvedValueOnce(null)
|
||||
store = useModelStore()
|
||||
await store.loadModelFolders()
|
||||
const folderStore = await store.getLoadedModelFolder('checkpoints')
|
||||
const model = folderStore!.models['0/sdxl.safetensors']
|
||||
|
||||
await model.load()
|
||||
|
||||
expect(model.title).toBe('sdxl')
|
||||
expect(model.has_loaded_metadata).toBe(false)
|
||||
})
|
||||
|
||||
it('loads model metadata once', async () => {
|
||||
enableMocks()
|
||||
store = useModelStore()
|
||||
await store.loadModelFolders()
|
||||
const folderStore = await store.getLoadedModelFolder('checkpoints')
|
||||
const model = folderStore!.models['0/sdxl.safetensors']
|
||||
|
||||
await model.load()
|
||||
await model.load()
|
||||
|
||||
expect(api.viewMetadata).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('keeps the default title when the first metadata key is empty', async () => {
|
||||
enableMocks()
|
||||
vi.mocked(api.viewMetadata).mockResolvedValueOnce({
|
||||
'modelspec.title': '',
|
||||
display_name: 'Fallback title'
|
||||
})
|
||||
store = useModelStore()
|
||||
await store.loadModelFolders()
|
||||
const folderStore = await store.getLoadedModelFolder('checkpoints')
|
||||
const model = folderStore!.models['0/sdxl.safetensors']
|
||||
|
||||
await model.load()
|
||||
|
||||
expect(model.title).toBe('sdxl')
|
||||
})
|
||||
|
||||
it('returns null for unknown loaded model folders', async () => {
|
||||
enableMocks()
|
||||
store = useModelStore()
|
||||
await store.loadModelFolders()
|
||||
|
||||
await expect(store.getLoadedModelFolder('missing')).resolves.toBeNull()
|
||||
})
|
||||
|
||||
it('should read metadata from suffixed keys and ignore null values', async () => {
|
||||
enableMocks()
|
||||
vi.mocked(api.viewMetadata).mockResolvedValueOnce({
|
||||
'custom.modelspec.title': 'Namespaced title',
|
||||
'custom.modelspec.author': null,
|
||||
'custom.modelspec.tags': null
|
||||
})
|
||||
store = useModelStore()
|
||||
await store.loadModelFolders()
|
||||
const folderStore = await store.getLoadedModelFolder('checkpoints')
|
||||
const model = folderStore!.models['0/sdxl.safetensors']
|
||||
|
||||
await model.load()
|
||||
|
||||
expect(model.title).toBe('Namespaced title')
|
||||
expect(model.author).toBe('')
|
||||
expect(model.tags).toEqual([''])
|
||||
})
|
||||
|
||||
it('should keep extensions for non-safetensors files', async () => {
|
||||
enableMocks()
|
||||
vi.mocked(api.getModels).mockResolvedValueOnce([
|
||||
{ name: 'notes.txt', pathIndex: 0 }
|
||||
])
|
||||
store = useModelStore()
|
||||
await store.loadModelFolders()
|
||||
const folderStore = await store.getLoadedModelFolder('checkpoints')
|
||||
|
||||
expect(folderStore!.models['0/notes.txt'].title).toBe('notes.txt')
|
||||
})
|
||||
|
||||
it('should cache model information', async () => {
|
||||
enableMocks()
|
||||
store = useModelStore()
|
||||
@@ -209,6 +291,23 @@ describe('useModelStore', () => {
|
||||
expect(api.getModelFolders).toHaveBeenCalledTimes(2)
|
||||
expect(api.getModels).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not reload previously loaded folders that disappear', async () => {
|
||||
enableMocks()
|
||||
store = useModelStore()
|
||||
await store.loadModelFolders()
|
||||
await store.getLoadedModelFolder('checkpoints')
|
||||
vi.mocked(api.getModelFolders).mockResolvedValueOnce([
|
||||
{ name: 'vae', folders: ['/path/to/vae'] }
|
||||
])
|
||||
|
||||
await store.refresh()
|
||||
|
||||
expect(store.modelFolders.map((folder) => folder.directory)).toEqual([
|
||||
'vae'
|
||||
])
|
||||
expect(api.getModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('API switching functionality', () => {
|
||||
|
||||
@@ -138,6 +138,22 @@ describe('useModelToNodeStore', () => {
|
||||
expect(provider?.key).toBe('ckpt_name')
|
||||
})
|
||||
|
||||
it('omits providers whose node definition is unavailable from reverse lookup', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.modelToNodeMap = {
|
||||
missing: [
|
||||
new ModelNodeProvider(
|
||||
undefined as unknown as ComfyNodeDefImpl,
|
||||
'model'
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
expect(modelToNodeStore.getRegisteredNodeTypes()).not.toHaveProperty(
|
||||
'undefined'
|
||||
)
|
||||
})
|
||||
|
||||
it('should return undefined for unregistered model type', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
@@ -577,6 +593,22 @@ describe('useModelToNodeStore', () => {
|
||||
expect(modelToNodeStore.getCategoryForNodeType('')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('skips providers without node definitions during category lookup', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.modelToNodeMap = {
|
||||
missing: [
|
||||
new ModelNodeProvider(
|
||||
undefined as unknown as ComfyNodeDefImpl,
|
||||
'model'
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
expect(
|
||||
modelToNodeStore.getCategoryForNodeType('MissingNode')
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('maps the IC-LoRA Loader Model Only node to loras so its lora_name dropdown uses the cloud asset browser (FE-838)', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { fromAny } from '@total-typescript/shoehorn'
|
||||
import axios from 'axios'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { promoteValueWidgetViaSubgraphInput } from '@/core/graph/subgraph/promotionUtils'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraphNode, LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { LGraph, SubgraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import {
|
||||
createTestSubgraph,
|
||||
createTestSubgraphNode
|
||||
} from '@/lib/litegraph/src/subgraph/__fixtures__/subgraphHelpers'
|
||||
import type { ComfyNodeDef } from '@/schemas/nodeDefSchema'
|
||||
import { useNodeDefStore } from '@/stores/nodeDefStore'
|
||||
import {
|
||||
ComfyNodeDefImpl,
|
||||
buildNodeDefTree,
|
||||
createDummyFolderNodeDef,
|
||||
useNodeDefStore,
|
||||
useNodeFrequencyStore
|
||||
} from '@/stores/nodeDefStore'
|
||||
import type { NodeDefFilter } from '@/stores/nodeDefStore'
|
||||
|
||||
describe('useNodeDefStore', () => {
|
||||
@@ -21,6 +29,10 @@ describe('useNodeDefStore', () => {
|
||||
store = useNodeDefStore()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
const createMockNodeDef = (
|
||||
overrides: Partial<ComfyNodeDef> = {}
|
||||
): ComfyNodeDef => ({
|
||||
@@ -39,7 +51,112 @@ describe('useNodeDefStore', () => {
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('ComfyNodeDefImpl', () => {
|
||||
it('migrates defaultInput options and applies constructor fallbacks', () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
const nodeDef = createMockNodeDef({
|
||||
category: '_for_testing/coverage',
|
||||
deprecated: undefined,
|
||||
dev_only: undefined,
|
||||
experimental: undefined,
|
||||
help: undefined,
|
||||
input: {
|
||||
required: { prompt: ['STRING', { defaultInput: true }] },
|
||||
optional: { seed_override: ['INT', { defaultInput: true }] }
|
||||
}
|
||||
})
|
||||
|
||||
const impl = new ComfyNodeDefImpl(nodeDef)
|
||||
|
||||
expect(warn).toHaveBeenCalledTimes(2)
|
||||
expect(impl.help).toBe('')
|
||||
expect(impl.experimental).toBe(true)
|
||||
expect(impl.dev_only).toBe(false)
|
||||
expect(impl.inputs.seed_override.forceInput).toBe(true)
|
||||
})
|
||||
|
||||
it('derives empty-category node paths and lifecycle badges', () => {
|
||||
const deprecated = new ComfyNodeDefImpl(
|
||||
createMockNodeDef({ category: '', deprecated: undefined })
|
||||
)
|
||||
const beta = new ComfyNodeDefImpl(
|
||||
createMockNodeDef({ experimental: true })
|
||||
)
|
||||
const dev = new ComfyNodeDefImpl(createMockNodeDef({ dev_only: true }))
|
||||
const normal = new ComfyNodeDefImpl(createMockNodeDef())
|
||||
|
||||
expect(deprecated.nodePath).toBe('TestNode')
|
||||
expect(deprecated.isDummyFolder).toBe(false)
|
||||
expect(deprecated.nodeLifeCycleBadgeText).toBe('[DEPR]')
|
||||
expect(beta.nodeLifeCycleBadgeText).toBe('[BETA]')
|
||||
expect(dev.nodeLifeCycleBadgeText).toBe('[DEV]')
|
||||
expect(normal.nodeLifeCycleBadgeText).toBe('')
|
||||
})
|
||||
|
||||
it('defaults missing legacy input and output fields', () => {
|
||||
const nodeDef = new ComfyNodeDefImpl(
|
||||
fromAny<ComfyNodeDef, unknown>({
|
||||
name: 'FallbackNode',
|
||||
display_name: 'Fallback Node',
|
||||
category: 'test',
|
||||
python_module: 'test_module',
|
||||
description: 'Test node',
|
||||
output_node: false
|
||||
})
|
||||
)
|
||||
|
||||
expect(nodeDef.input).toEqual({})
|
||||
expect(nodeDef.output).toEqual([])
|
||||
})
|
||||
|
||||
it('post-processes search scores with node frequency', async () => {
|
||||
vi.spyOn(axios, 'get').mockResolvedValue({ data: { TestNode: 7 } })
|
||||
const frequencyStore = useNodeFrequencyStore()
|
||||
await frequencyStore.loadNodeFrequencies()
|
||||
const nodeDef = new ComfyNodeDefImpl(createMockNodeDef())
|
||||
|
||||
expect(nodeDef.postProcessSearchScores([10, 4, 2])).toEqual([
|
||||
10, -7, 4, 2
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('tree helpers', () => {
|
||||
it('builds node definition trees from default and custom paths', () => {
|
||||
const nodeDef = new ComfyNodeDefImpl(
|
||||
createMockNodeDef({ name: 'TreeNode', category: 'root/branch' })
|
||||
)
|
||||
|
||||
expect(buildNodeDefTree([nodeDef]).children?.[0].label).toBe('root')
|
||||
expect(
|
||||
buildNodeDefTree([nodeDef], {
|
||||
pathExtractor: (node) => ['custom', node.name]
|
||||
}).children?.[0].label
|
||||
).toBe('custom')
|
||||
})
|
||||
|
||||
it('normalizes dummy folder paths', () => {
|
||||
expect(createDummyFolderNodeDef('folder/').category).toBe('folder')
|
||||
expect(createDummyFolderNodeDef('folder').category).toBe('folder')
|
||||
})
|
||||
})
|
||||
|
||||
describe('filter registry', () => {
|
||||
it('updates LiteGraph skip state for registered dev-only nodes', () => {
|
||||
const registeredNodeTypes = LiteGraph.registered_node_types
|
||||
LiteGraph.registered_node_types = fromAny({
|
||||
DevNode: { nodeData: { dev_only: true }, skip_list: false },
|
||||
NormalNode: { nodeData: {}, skip_list: false }
|
||||
})
|
||||
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
useNodeDefStore()
|
||||
|
||||
expect(LiteGraph.registered_node_types.DevNode.skip_list).toBe(true)
|
||||
expect(LiteGraph.registered_node_types.NormalNode.skip_list).toBe(false)
|
||||
LiteGraph.registered_node_types = registeredNodeTypes
|
||||
})
|
||||
|
||||
it('should register a new filter', () => {
|
||||
const filter: NodeDefFilter = {
|
||||
id: 'test.filter',
|
||||
@@ -287,6 +404,26 @@ describe('useNodeDefStore', () => {
|
||||
})
|
||||
|
||||
describe('allNodeDefsByName', () => {
|
||||
it('keeps existing ComfyNodeDefImpl instances during updates', () => {
|
||||
const nodeDef = new ComfyNodeDefImpl(
|
||||
createMockNodeDef({ name: 'ExistingImpl' })
|
||||
)
|
||||
|
||||
store.updateNodeDefs([nodeDef])
|
||||
|
||||
expect(store.nodeDefsByName.ExistingImpl.name).toBe('ExistingImpl')
|
||||
expect(store.nodeDefsByDisplayName['Test Node'].name).toBe('ExistingImpl')
|
||||
})
|
||||
|
||||
it('adds one node definition to the name and display-name indexes', () => {
|
||||
store.addNodeDef(
|
||||
createMockNodeDef({ name: 'AddedNode', display_name: 'Added Node' })
|
||||
)
|
||||
|
||||
expect(store.nodeDefsByName.AddedNode.name).toBe('AddedNode')
|
||||
expect(store.nodeDefsByDisplayName['Added Node'].name).toBe('AddedNode')
|
||||
})
|
||||
|
||||
it('should include all node defs by name', () => {
|
||||
const node1 = createMockNodeDef({ name: 'Node1' })
|
||||
const node2 = createMockNodeDef({ name: 'Node2' })
|
||||
@@ -336,6 +473,39 @@ describe('useNodeDefStore', () => {
|
||||
expect(store.allNodeDefsByName).toHaveProperty('Normal')
|
||||
expect(store.allNodeDefsByName).toHaveProperty('Deprecated')
|
||||
})
|
||||
|
||||
it('derives unique input and output data types', () => {
|
||||
store.updateNodeDefs([
|
||||
createMockNodeDef({
|
||||
input: {
|
||||
required: { image: ['IMAGE', {}] },
|
||||
optional: { mask: ['MASK', {}] }
|
||||
},
|
||||
output: ['IMAGE', 'LATENT'],
|
||||
output_is_list: [false, false],
|
||||
output_name: ['image', 'latent']
|
||||
})
|
||||
])
|
||||
|
||||
expect([...store.nodeDataTypes].sort()).toEqual([
|
||||
'IMAGE',
|
||||
'LATENT',
|
||||
'MASK'
|
||||
])
|
||||
})
|
||||
|
||||
it('looks up node definitions from graph nodes and returns null for misses', () => {
|
||||
store.updateNodeDefs([createMockNodeDef({ name: 'KnownNode' })])
|
||||
|
||||
expect(
|
||||
store.fromLGraphNode(new LGraphNode('KnownNode', 'KnownNode'))?.name
|
||||
).toBe('KnownNode')
|
||||
expect(store.fromLGraphNode(new LGraphNode('', ''))).toBeNull()
|
||||
expect(
|
||||
store.getInputSpecForWidget(new LGraphNode('Missing', 'Missing'), 'x')
|
||||
).toBeUndefined()
|
||||
expect(store.nodeSearchService).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('subgraph widget input specs', () => {
|
||||
@@ -389,6 +559,94 @@ describe('useNodeDefStore', () => {
|
||||
expect(spec?.type).toBe('STRING')
|
||||
expect(spec?.default).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for missing promoted subgraph inputs', () => {
|
||||
const host = setupPromotedPrompt(
|
||||
createMockNodeDef({
|
||||
name: 'PromptNode',
|
||||
input: { required: { prompt: ['STRING', {}] } }
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.getInputSpecForWidget(host, 'missing')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when a subgraph input is not promoted', () => {
|
||||
const subgraph = createTestSubgraph()
|
||||
const host = createTestSubgraphNode(subgraph)
|
||||
host.addInput('raw', 'STRING')
|
||||
|
||||
expect(store.getInputSpecForWidget(host, 'raw')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when a promoted source no longer resolves', () => {
|
||||
const host = setupPromotedPrompt(
|
||||
createMockNodeDef({
|
||||
name: 'PromptNode',
|
||||
input: { required: { prompt: ['STRING', {}] } }
|
||||
})
|
||||
)
|
||||
host.subgraph.nodes[0].widgets = []
|
||||
|
||||
expect(store.getInputSpecForWidget(host, 'prompt')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when concrete promoted widget resolution fails', async () => {
|
||||
const resolver =
|
||||
await import('@/core/graph/subgraph/resolveConcretePromotedWidget')
|
||||
vi.spyOn(resolver, 'resolveConcretePromotedWidget').mockReturnValue(
|
||||
fromAny({ status: 'failure', failure: 'missing-widget' })
|
||||
)
|
||||
const host = setupPromotedPrompt(
|
||||
createMockNodeDef({
|
||||
name: 'PromptNode',
|
||||
input: { required: { prompt: ['STRING', {}] } }
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.getInputSpecForWidget(host, 'prompt')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('node frequency store', () => {
|
||||
it('loads frequencies once and exposes top matching node definitions', async () => {
|
||||
const get = vi.spyOn(axios, 'get').mockResolvedValue({
|
||||
data: { RankedNode: 10, MissingNode: 3 }
|
||||
})
|
||||
store.updateNodeDefs([createMockNodeDef({ name: 'RankedNode' })])
|
||||
const frequencyStore = useNodeFrequencyStore()
|
||||
|
||||
await frequencyStore.loadNodeFrequencies()
|
||||
await frequencyStore.loadNodeFrequencies()
|
||||
|
||||
expect(get).toHaveBeenCalledTimes(1)
|
||||
expect(frequencyStore.isLoaded).toBe(true)
|
||||
expect(frequencyStore.getNodeFrequencyByName('RankedNode')).toBe(10)
|
||||
expect(
|
||||
frequencyStore.getNodeFrequency(
|
||||
new ComfyNodeDefImpl(createMockNodeDef({ name: 'RankedNode' }))
|
||||
)
|
||||
).toBe(10)
|
||||
expect(frequencyStore.getNodeFrequencyByName('Unknown')).toBe(0)
|
||||
expect(frequencyStore.topNodeDefs.map((nodeDef) => nodeDef.name)).toEqual(
|
||||
['RankedNode']
|
||||
)
|
||||
})
|
||||
|
||||
it('leaves frequency state unloaded when loading fails', async () => {
|
||||
const error = new Error('boom')
|
||||
vi.spyOn(axios, 'get').mockRejectedValue(error)
|
||||
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
const frequencyStore = useNodeFrequencyStore()
|
||||
|
||||
await frequencyStore.loadNodeFrequencies()
|
||||
|
||||
expect(frequencyStore.isLoaded).toBe(false)
|
||||
expect(errorSpy).toHaveBeenCalledWith(
|
||||
'Error loading node frequencies:',
|
||||
error
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('performance', () => {
|
||||
|
||||
@@ -3,15 +3,41 @@ import { fromAny } from '@total-typescript/shoehorn'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { LGraphNode, SubgraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { ExecutedWsMessage } from '@/schemas/apiSchema'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useNodeOutputStore } from '@/stores/nodeOutputStore'
|
||||
import { createNodeExecutionId } from '@/types/nodeIdentification'
|
||||
import {
|
||||
createNodeExecutionId,
|
||||
createNodeLocatorId
|
||||
} from '@/types/nodeIdentification'
|
||||
import type { NodeExecutionId, NodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
import * as litegraphUtil from '@/utils/litegraphUtil'
|
||||
|
||||
const {
|
||||
mockApiURL,
|
||||
mockExecutionIdToNodeLocatorId,
|
||||
mockNodeIdToNodeLocatorId,
|
||||
mockNodeToNodeLocatorId,
|
||||
mockReleaseSharedObjectUrl,
|
||||
mockRetainSharedObjectUrl
|
||||
} = vi.hoisted(() => ({
|
||||
mockApiURL: vi.fn((path: string) => `api${path}`),
|
||||
mockExecutionIdToNodeLocatorId: vi.fn(
|
||||
(_rootGraph: unknown, id: NodeExecutionId) => id as unknown as NodeLocatorId
|
||||
),
|
||||
mockNodeIdToNodeLocatorId: vi.fn(
|
||||
(id: string | number) => String(id) as NodeLocatorId
|
||||
),
|
||||
mockNodeToNodeLocatorId: vi.fn(
|
||||
(node: { id: string | number }) => String(node.id) as NodeLocatorId
|
||||
),
|
||||
mockReleaseSharedObjectUrl: vi.fn(),
|
||||
mockRetainSharedObjectUrl: vi.fn()
|
||||
}))
|
||||
|
||||
const mockResolveNode = vi.fn()
|
||||
|
||||
vi.mock('@/utils/litegraphUtil', () => ({
|
||||
@@ -20,11 +46,25 @@ vi.mock('@/utils/litegraphUtil', () => ({
|
||||
resolveNode: (...args: unknown[]) => mockResolveNode(...args)
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
apiURL: (...args: Parameters<typeof mockApiURL>) => mockApiURL(...args)
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/objectUrlUtil', () => ({
|
||||
releaseSharedObjectUrl: (...args: [string | undefined]) =>
|
||||
mockReleaseSharedObjectUrl(...args),
|
||||
retainSharedObjectUrl: (...args: [string | undefined]) =>
|
||||
mockRetainSharedObjectUrl(...args)
|
||||
}))
|
||||
|
||||
const mockGetNodeById = vi.fn()
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
getPreviewFormatParam: vi.fn(() => '&format=test_webp'),
|
||||
getRandParam: vi.fn(() => '&rand=1'),
|
||||
rootGraph: {
|
||||
getNodeById: (...args: unknown[]) => mockGetNodeById(...args)
|
||||
},
|
||||
@@ -49,13 +89,31 @@ const createMockOutputs = (
|
||||
): ExecutedWsMessage['output'] => ({ images })
|
||||
|
||||
vi.mock('@/utils/graphTraversalUtil', () => ({
|
||||
executionIdToNodeLocatorId: vi.fn((_rootGraph: unknown, id: string) => id)
|
||||
executionIdToNodeLocatorId: (
|
||||
...args: Parameters<typeof mockExecutionIdToNodeLocatorId>
|
||||
) => mockExecutionIdToNodeLocatorId(...args)
|
||||
}))
|
||||
|
||||
beforeEach(() => {
|
||||
mockExecutionIdToNodeLocatorId.mockImplementation(
|
||||
(_rootGraph: unknown, id: NodeExecutionId) => id as unknown as NodeLocatorId
|
||||
)
|
||||
mockNodeIdToNodeLocatorId.mockImplementation(
|
||||
(id: string | number) => String(id) as NodeLocatorId
|
||||
)
|
||||
mockNodeToNodeLocatorId.mockImplementation(
|
||||
(node: { id: string | number }) => String(node.id) as NodeLocatorId
|
||||
)
|
||||
})
|
||||
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', () => ({
|
||||
useWorkflowStore: vi.fn(() => ({
|
||||
nodeIdToNodeLocatorId: vi.fn((id: string | number) => String(id)),
|
||||
nodeToNodeLocatorId: vi.fn((node: { id: number }) => String(node.id))
|
||||
nodeIdToNodeLocatorId: (
|
||||
...args: Parameters<typeof mockNodeIdToNodeLocatorId>
|
||||
) => mockNodeIdToNodeLocatorId(...args),
|
||||
nodeToNodeLocatorId: (
|
||||
...args: Parameters<typeof mockNodeToNodeLocatorId>
|
||||
) => mockNodeToNodeLocatorId(...args)
|
||||
}))
|
||||
}))
|
||||
|
||||
@@ -780,6 +838,19 @@ describe('nodeOutputStore setNodeOutputs (widget path)', () => {
|
||||
expect(store.nodeOutputs['5']?.images?.[0]?.type).toBe('input')
|
||||
})
|
||||
|
||||
it('ignores widget outputs when no locator can be resolved', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
mockNodeToNodeLocatorId.mockReturnValueOnce(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
store.setNodeOutputs(node, 'test.png')
|
||||
|
||||
expect(store.nodeOutputs).toEqual({})
|
||||
expect(app.nodeOutputs).toEqual({})
|
||||
})
|
||||
|
||||
it('should skip empty array of filenames after createOutputs', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
@@ -789,6 +860,470 @@ describe('nodeOutputStore setNodeOutputs (widget path)', () => {
|
||||
expect(store.nodeOutputs['5']).toBeUndefined()
|
||||
expect(app.nodeOutputs['5']).toBeUndefined()
|
||||
})
|
||||
|
||||
it('stores direct result items without wrapping them as image outputs', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
|
||||
store.setNodeOutputs(node, { filename: 'direct.png', type: 'temp' })
|
||||
|
||||
expect(store.nodeOutputs['5']).toEqual({
|
||||
filename: 'direct.png',
|
||||
type: 'temp'
|
||||
})
|
||||
})
|
||||
|
||||
it('marks animated webp and png filenames when requested', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
|
||||
store.setNodeOutputs(node, ['clip.webp', 'still.jpg', 'mask.png'], {
|
||||
folder: 'output',
|
||||
isAnimated: true
|
||||
})
|
||||
|
||||
expect(store.nodeOutputs['5']?.animated).toEqual([true, false, true])
|
||||
expect(store.nodeOutputs['5']?.images?.map((image) => image.type)).toEqual([
|
||||
'output',
|
||||
'output',
|
||||
'output'
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('nodeOutputStore image URLs', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(litegraphUtil.isAnimatedOutput).mockReturnValue(false)
|
||||
vi.mocked(litegraphUtil.isVideoNode).mockReturnValue(false)
|
||||
app.nodeOutputs = {}
|
||||
app.nodePreviewImages = {}
|
||||
})
|
||||
|
||||
it('returns stored preview URLs before output URLs', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
|
||||
store.setNodePreviewsByLocatorId(createNodeLocatorId(null, toNodeId(5)), [
|
||||
'blob:preview'
|
||||
])
|
||||
|
||||
expect(store.getNodeImageUrls(node)).toEqual(['blob:preview'])
|
||||
expect(mockApiURL).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('builds view URLs from output images', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
app.nodeOutputs['5'] = createMockOutputs(
|
||||
fromAny([{ filename: 'a.png', subfolder: 'x', type: 'temp' }, null])
|
||||
)
|
||||
|
||||
expect(store.getNodeImageUrls(node)).toEqual([
|
||||
'api/view?filename=a.png&subfolder=x&type=temp&format=test_webp&rand=1'
|
||||
])
|
||||
})
|
||||
|
||||
it('returns undefined when a node has neither previews nor outputs', () => {
|
||||
const store = useNodeOutputStore()
|
||||
|
||||
expect(store.getNodeImageUrls(createMockNode({ id: 5 }))).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns execution previews before execution output URLs', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
|
||||
store.setNodePreviewsByExecutionId(executionId, ['blob:preview'])
|
||||
|
||||
expect(store.getNodeImageUrlsByExecutionId(executionId, node)).toEqual([
|
||||
'blob:preview'
|
||||
])
|
||||
expect(store.latestPreview).toEqual(['blob:preview'])
|
||||
expect(mockApiURL).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('falls back to execution output URLs when no preview exists', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
|
||||
store.setNodeOutputsByExecutionId(
|
||||
executionId,
|
||||
createMockOutputs([{ filename: 'result.png', type: 'temp' }])
|
||||
)
|
||||
|
||||
expect(store.getNodeImageUrlsByExecutionId(executionId, node)).toEqual([
|
||||
'api/view?filename=result.png&type=temp&format=test_webp&rand=1'
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('nodeOutputStore locator misses', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
app.nodeOutputs = {}
|
||||
app.nodePreviewImages = {}
|
||||
})
|
||||
|
||||
it('keeps execution operations inert when no locator can be resolved', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
mockExecutionIdToNodeLocatorId.mockReturnValue(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
store.setNodeOutputsByExecutionId(
|
||||
executionId,
|
||||
createMockOutputs([{ filename: 'result.png' }])
|
||||
)
|
||||
store.setNodePreviewsByExecutionId(executionId, ['blob:preview'])
|
||||
store.revokePreviewsByExecutionId(executionId)
|
||||
|
||||
expect(store.getNodeOutputByExecutionId(executionId)).toBeUndefined()
|
||||
expect(store.getNodePreviewImagesByExecutionId(executionId)).toBeUndefined()
|
||||
expect(store.nodeOutputs).toEqual({})
|
||||
expect(store.nodePreviewImages).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('nodeOutputStore merge branches', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
app.nodeOutputs = {}
|
||||
app.nodePreviewImages = {}
|
||||
})
|
||||
|
||||
it('sets outputs when merge is requested without existing output', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
const output = createMockOutputs([{ filename: 'first.png' }])
|
||||
|
||||
store.setNodeOutputsByExecutionId(executionId, output, { merge: true })
|
||||
|
||||
expect(store.nodeOutputs[executionId]).toEqual(output)
|
||||
})
|
||||
|
||||
it('ignores null outputs', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
|
||||
store.setNodeOutputsByExecutionId(
|
||||
executionId,
|
||||
fromAny<ExecutedWsMessage['output'], unknown>(null)
|
||||
)
|
||||
|
||||
expect(store.nodeOutputs[executionId]).toBeUndefined()
|
||||
})
|
||||
|
||||
it('overwrites non-array fields during merge', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
const firstOutput: ExecutedWsMessage['output'] = {
|
||||
images: [{ filename: 'first.png' }],
|
||||
text: 'old'
|
||||
}
|
||||
|
||||
store.setNodeOutputsByExecutionId(executionId, firstOutput)
|
||||
store.setNodeOutputsByExecutionId(
|
||||
executionId,
|
||||
{ text: ['new'] },
|
||||
{ merge: true }
|
||||
)
|
||||
|
||||
expect(store.nodeOutputs[executionId]?.images).toEqual([
|
||||
{ filename: 'first.png' }
|
||||
])
|
||||
expect(store.nodeOutputs[executionId]?.text).toEqual(['new'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('nodeOutputStore previews and removal', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
app.nodeOutputs = {}
|
||||
app.nodePreviewImages = {}
|
||||
})
|
||||
|
||||
it('releases old previews and retains new previews on replacement', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const locatorId = createNodeLocatorId(null, toNodeId(5))
|
||||
|
||||
store.setNodePreviewsByLocatorId(locatorId, ['blob:first'])
|
||||
store.setNodePreviewsByLocatorId(locatorId, ['blob:second'])
|
||||
|
||||
expect(mockReleaseSharedObjectUrl).toHaveBeenCalledWith('blob:first')
|
||||
expect(mockRetainSharedObjectUrl).toHaveBeenCalledWith('blob:second')
|
||||
expect(store.nodePreviewImages[locatorId]).toEqual(['blob:second'])
|
||||
})
|
||||
|
||||
it('starts with an empty preview map when legacy previews are missing', () => {
|
||||
app.nodePreviewImages = fromAny(undefined)
|
||||
|
||||
const store = useNodeOutputStore()
|
||||
|
||||
expect(store.nodePreviewImages).toEqual({})
|
||||
})
|
||||
|
||||
it('cancels scheduled revocation when a newer preview arrives', async () => {
|
||||
vi.useFakeTimers()
|
||||
const store = useNodeOutputStore()
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
|
||||
store.setNodePreviewsByExecutionId(executionId, ['blob:first'])
|
||||
store.revokePreviewsByExecutionId(executionId)
|
||||
store.setNodePreviewsByExecutionId(executionId, ['blob:second'])
|
||||
await vi.advanceTimersByTimeAsync(400)
|
||||
vi.useRealTimers()
|
||||
|
||||
expect(store.nodePreviewImages[executionId]).toEqual(['blob:second'])
|
||||
expect(mockReleaseSharedObjectUrl).not.toHaveBeenCalledWith('blob:second')
|
||||
})
|
||||
|
||||
it('revokes locator previews and clears preview state', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const locatorId = createNodeLocatorId(null, toNodeId(5))
|
||||
|
||||
store.setNodePreviewsByLocatorId(locatorId, ['blob:first'])
|
||||
store.revokePreviewsByLocatorId(locatorId)
|
||||
|
||||
expect(mockReleaseSharedObjectUrl).toHaveBeenCalledWith('blob:first')
|
||||
expect(store.nodePreviewImages[locatorId]).toBeUndefined()
|
||||
expect(app.nodePreviewImages[locatorId]).toBeUndefined()
|
||||
})
|
||||
|
||||
it('leaves state unchanged when revoking a locator with no previews', () => {
|
||||
const store = useNodeOutputStore()
|
||||
|
||||
store.revokePreviewsByLocatorId(createNodeLocatorId(null, toNodeId(5)))
|
||||
|
||||
expect(mockReleaseSharedObjectUrl).not.toHaveBeenCalled()
|
||||
expect(store.nodePreviewImages).toEqual({})
|
||||
})
|
||||
|
||||
it('skips non-iterable preview entries when revoking all previews', () => {
|
||||
const store = useNodeOutputStore()
|
||||
app.nodePreviewImages = fromAny({
|
||||
'5': {},
|
||||
'6': ['blob:preview']
|
||||
})
|
||||
|
||||
store.revokeAllPreviews()
|
||||
|
||||
expect(mockReleaseSharedObjectUrl).toHaveBeenCalledTimes(1)
|
||||
expect(mockReleaseSharedObjectUrl).toHaveBeenCalledWith('blob:preview')
|
||||
expect(store.nodePreviewImages).toEqual({})
|
||||
})
|
||||
|
||||
it('revokes subgraph previews for the parent node and child nodes', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const subgraphId = '11111111-1111-1111-1111-111111111111'
|
||||
const parentLocatorId = createNodeLocatorId(null, toNodeId(9))
|
||||
const childLocatorId = createNodeLocatorId(subgraphId, toNodeId(10))
|
||||
const subgraphNode = fromAny<SubgraphNode, unknown>({
|
||||
id: toNodeId(9),
|
||||
graph: { isRootGraph: true },
|
||||
subgraph: {
|
||||
id: subgraphId,
|
||||
nodes: [createMockNode({ id: 10 })]
|
||||
}
|
||||
})
|
||||
|
||||
store.setNodePreviewsByLocatorId(parentLocatorId, ['blob:parent'])
|
||||
store.setNodePreviewsByLocatorId(childLocatorId, ['blob:child'])
|
||||
store.revokeSubgraphPreviews(subgraphNode)
|
||||
|
||||
expect(store.nodePreviewImages[parentLocatorId]).toBeUndefined()
|
||||
expect(store.nodePreviewImages[childLocatorId]).toBeUndefined()
|
||||
expect(mockReleaseSharedObjectUrl).toHaveBeenCalledWith('blob:parent')
|
||||
expect(mockReleaseSharedObjectUrl).toHaveBeenCalledWith('blob:child')
|
||||
})
|
||||
|
||||
it('uses the parent graph id for non-root subgraph preview revocation', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const graphId = '22222222-2222-2222-2222-222222222222'
|
||||
const subgraphId = '33333333-3333-3333-3333-333333333333'
|
||||
const parentLocatorId = createNodeLocatorId(graphId, toNodeId(9))
|
||||
const subgraphNode = fromAny<SubgraphNode, unknown>({
|
||||
id: toNodeId(9),
|
||||
graph: { id: graphId, isRootGraph: false },
|
||||
subgraph: { id: subgraphId, nodes: [] }
|
||||
})
|
||||
|
||||
store.setNodePreviewsByLocatorId(parentLocatorId, ['blob:parent'])
|
||||
store.revokeSubgraphPreviews(subgraphNode)
|
||||
|
||||
expect(store.nodePreviewImages[parentLocatorId]).toBeUndefined()
|
||||
})
|
||||
|
||||
it('leaves previews alone when a subgraph node has no parent graph', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const locatorId = createNodeLocatorId(null, toNodeId(9))
|
||||
const subgraphNode = fromAny<SubgraphNode, unknown>({
|
||||
graph: undefined,
|
||||
subgraph: { nodes: [] }
|
||||
})
|
||||
|
||||
store.setNodePreviewsByLocatorId(locatorId, ['blob:parent'])
|
||||
store.revokeSubgraphPreviews(subgraphNode)
|
||||
|
||||
expect(store.nodePreviewImages[locatorId]).toEqual(['blob:parent'])
|
||||
})
|
||||
|
||||
it('removes outputs and previews for a node id', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
|
||||
store.setNodeOutputsByExecutionId(
|
||||
executionId,
|
||||
createMockOutputs([{ filename: 'result.png' }])
|
||||
)
|
||||
store.setNodePreviewsByExecutionId(executionId, ['blob:preview'])
|
||||
|
||||
expect(store.removeNodeOutputs(toNodeId(5))).toBe(true)
|
||||
expect(store.nodeOutputs[executionId]).toBeUndefined()
|
||||
expect(store.nodePreviewImages[executionId]).toBeUndefined()
|
||||
expect(mockReleaseSharedObjectUrl).toHaveBeenCalledWith('blob:preview')
|
||||
})
|
||||
|
||||
it('returns false when removing outputs for a node with no outputs', () => {
|
||||
const store = useNodeOutputStore()
|
||||
|
||||
expect(store.removeNodeOutputsForNode(createMockNode({ id: 9 }))).toBe(
|
||||
false
|
||||
)
|
||||
})
|
||||
|
||||
it('returns false when a node id cannot resolve to a locator', () => {
|
||||
const store = useNodeOutputStore()
|
||||
mockNodeIdToNodeLocatorId.mockReturnValueOnce(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
expect(store.removeNodeOutputs(toNodeId(9))).toBe(false)
|
||||
})
|
||||
|
||||
it('removes preview state even when preview entries are not iterable', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const executionId = createNodeExecutionId([toNodeId(5)])
|
||||
|
||||
store.setNodeOutputsByExecutionId(
|
||||
executionId,
|
||||
createMockOutputs([{ filename: 'result.png' }])
|
||||
)
|
||||
app.nodePreviewImages[executionId] = fromAny({})
|
||||
store.nodePreviewImages[executionId] = fromAny({})
|
||||
|
||||
expect(store.removeNodeOutputs(toNodeId(5))).toBe(true)
|
||||
expect(store.nodePreviewImages[executionId]).toBeUndefined()
|
||||
expect(mockReleaseSharedObjectUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('nodeOutputStore output refresh', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.clearAllMocks()
|
||||
app.nodeOutputs = {}
|
||||
app.nodePreviewImages = {}
|
||||
})
|
||||
|
||||
it('updates stored output images from legacy node images', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({
|
||||
id: 5,
|
||||
images: [{ filename: 'new.png', type: 'temp' }]
|
||||
})
|
||||
|
||||
store.setNodeOutputsByExecutionId(
|
||||
createNodeExecutionId([toNodeId(5)]),
|
||||
createMockOutputs([{ filename: 'old.png', type: 'temp' }])
|
||||
)
|
||||
store.updateNodeImages(node)
|
||||
|
||||
expect(store.nodeOutputs['5']?.images).toEqual([
|
||||
{ filename: 'new.png', type: 'temp' }
|
||||
])
|
||||
})
|
||||
|
||||
it('ignores legacy image updates when the node has no images', () => {
|
||||
const store = useNodeOutputStore()
|
||||
|
||||
store.updateNodeImages(createMockNode({ id: 5 }))
|
||||
|
||||
expect(store.nodeOutputs).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores legacy image updates when no locator exists', () => {
|
||||
const store = useNodeOutputStore()
|
||||
mockNodeIdToNodeLocatorId.mockReturnValueOnce(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
store.updateNodeImages(
|
||||
createMockNode({ id: 5, images: [{ filename: 'new.png' }] })
|
||||
)
|
||||
|
||||
expect(store.nodeOutputs).toEqual({})
|
||||
})
|
||||
|
||||
it('ignores legacy image updates when no output exists', () => {
|
||||
const store = useNodeOutputStore()
|
||||
|
||||
store.updateNodeImages(
|
||||
createMockNode({ id: 5, images: [{ filename: 'new.png' }] })
|
||||
)
|
||||
|
||||
expect(store.nodeOutputs).toEqual({})
|
||||
})
|
||||
|
||||
it('copies app outputs into reactive state during refresh', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const node = createMockNode({ id: 5 })
|
||||
const output = createMockOutputs([{ filename: 'result.png' }])
|
||||
app.nodeOutputs['5'] = output
|
||||
|
||||
store.refreshNodeOutputs(node)
|
||||
|
||||
expect(store.nodeOutputs['5']).toEqual(output)
|
||||
expect(store.nodeOutputs['5']).not.toBe(output)
|
||||
})
|
||||
|
||||
it('does not refresh when a node has no locator', () => {
|
||||
const store = useNodeOutputStore()
|
||||
mockNodeToNodeLocatorId.mockReturnValueOnce(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
store.refreshNodeOutputs(createMockNode({ id: 5 }))
|
||||
|
||||
expect(store.nodeOutputs).toEqual({})
|
||||
})
|
||||
|
||||
it('does not refresh when app has no output for the node', () => {
|
||||
const store = useNodeOutputStore()
|
||||
|
||||
store.refreshNodeOutputs(createMockNode({ id: 5 }))
|
||||
|
||||
expect(store.nodeOutputs).toEqual({})
|
||||
})
|
||||
|
||||
it('keeps unresolved restore output ids as their original ids', () => {
|
||||
const store = useNodeOutputStore()
|
||||
const output = createMockOutputs([{ filename: 'saved.png' }])
|
||||
mockExecutionIdToNodeLocatorId.mockReturnValueOnce(
|
||||
fromAny<NodeLocatorId, undefined>(undefined)
|
||||
)
|
||||
|
||||
store.restoreOutputs({ missing: output })
|
||||
|
||||
expect(store.nodeOutputs.missing).toEqual(output)
|
||||
})
|
||||
})
|
||||
|
||||
describe('nodeOutputStore syncLegacyNodeImgs', () => {
|
||||
@@ -894,4 +1429,20 @@ describe('nodeOutputStore syncLegacyNodeImgs', () => {
|
||||
expect(mockNode.imgs).toEqual([mockImg])
|
||||
expect(mockNode.imageIndex).toBe(0)
|
||||
})
|
||||
|
||||
it('copies output images onto the legacy node', () => {
|
||||
LiteGraph.vueNodesMode = true
|
||||
const store = useNodeOutputStore()
|
||||
const mockNode = createMockNode({ id: 1 })
|
||||
const mockImg = document.createElement('img')
|
||||
mockResolveNode.mockReturnValue(mockNode)
|
||||
|
||||
store.setNodeOutputsByExecutionId(
|
||||
createNodeExecutionId([toNodeId(1)]),
|
||||
createMockOutputs([{ filename: 'result.png', type: 'temp' }])
|
||||
)
|
||||
store.syncLegacyNodeImgs(toNodeId(1), mockImg)
|
||||
|
||||
expect(mockNode.images).toEqual([{ filename: 'result.png', type: 'temp' }])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -95,6 +95,22 @@ describe(usePreviewExposureStore, () => {
|
||||
|
||||
expect(store.getExposures(rootGraphA, hostA)).toEqual([])
|
||||
})
|
||||
|
||||
it('clears only the requested host when other hosts remain', () => {
|
||||
store.addExposure(rootGraphA, hostA, {
|
||||
sourceNodeId: '42',
|
||||
sourcePreviewName: 'preview'
|
||||
})
|
||||
store.addExposure(rootGraphA, hostB, {
|
||||
sourceNodeId: '43',
|
||||
sourcePreviewName: 'preview'
|
||||
})
|
||||
|
||||
store.setExposures(rootGraphA, hostA, [])
|
||||
|
||||
expect(store.getExposures(rootGraphA, hostA)).toEqual([])
|
||||
expect(store.getExposures(rootGraphA, hostB)).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('removeExposure', () => {
|
||||
@@ -122,6 +138,12 @@ describe(usePreviewExposureStore, () => {
|
||||
store.removeExposure(rootGraphA, hostA, 'does-not-exist')
|
||||
expect(store.getExposures(rootGraphA, hostA)).toEqual(before)
|
||||
})
|
||||
|
||||
it('is a no-op for an unknown host', () => {
|
||||
store.removeExposure(rootGraphA, 'missing-host', 'preview')
|
||||
|
||||
expect(store.getExposures(rootGraphA, 'missing-host')).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getExposuresAsPromotionShape', () => {
|
||||
|
||||
Reference in New Issue
Block a user