diff --git a/src/composables/graph/useErrorClearingHooks.test.ts b/src/composables/graph/useErrorClearingHooks.test.ts index e132c00575..0a4b351333 100644 --- a/src/composables/graph/useErrorClearingHooks.test.ts +++ b/src/composables/graph/useErrorClearingHooks.test.ts @@ -21,6 +21,11 @@ import { useMissingNodesErrorStore } from '@/platform/nodeReplacement/missingNod import { app } from '@/scripts/app' import { useExecutionErrorStore } from '@/stores/executionErrorStore' import { seedRequiredInputMissingNodeError } from '@/utils/__tests__/executionErrorTestUtils' +import type { MissingModelCandidate } from '@/platform/missingModel/types' + +beforeEach(() => { + vi.restoreAllMocks() +}) describe('Connection error clearing via onConnectionsChange', () => { beforeEach(() => { @@ -347,6 +352,90 @@ describe('installErrorClearingHooks lifecycle', () => { installErrorClearingHooks(graph) expect(node.onConnectionsChange).toBe(chainedAfterFirst) }) + + it('scans added-node missing models after widget values are restored', async () => { + const graph = new LGraph() + vi.spyOn(app, 'rootGraph', 'get').mockReturnValue(graph) + installErrorClearingHooks(graph) + + const node = new LGraphNode('CheckpointLoaderSimple') + node.type = 'CheckpointLoaderSimple' + const widget = node.addWidget('combo', 'ckpt_name', '', () => undefined, { + values: [] + }) + + graph.add(node) + widget.value = 'fake_model.safetensors' + + await Promise.resolve() + + expect(useMissingModelStore().missingModelCandidates).toEqual([ + expect.objectContaining({ name: 'fake_model.safetensors' }) + ]) + }) + + it('scans added-node missing models before the deferred media scan', async () => { + const graph = new LGraph() + vi.spyOn(app, 'rootGraph', 'get').mockReturnValue(graph) + const modelScan = vi + .spyOn(missingModelScan, 'scanNodeModelCandidates') + .mockImplementation((_rootGraph, node) => [ + { + nodeId: String(node.id), + nodeType: node.type, + widgetName: 'ckpt_name', + isAssetSupported: false, + name: 'fake_model.safetensors', + directory: 'checkpoints', + isMissing: true + } satisfies MissingModelCandidate + ]) + const mediaScan = vi + .spyOn(missingMediaScan, 'scanNodeMediaCandidates') + .mockReturnValue([]) + installErrorClearingHooks(graph) + + const node = new LGraphNode('CheckpointLoaderSimple') + node.type = 'CheckpointLoaderSimple' + graph.add(node) + + await Promise.resolve() + + expect(modelScan).toHaveBeenCalledOnce() + expect(useMissingModelStore().missingModelCandidates).toEqual([ + expect.objectContaining({ name: 'fake_model.safetensors' }) + ]) + expect(mediaScan).not.toHaveBeenCalled() + + await Promise.resolve() + + expect(mediaScan).toHaveBeenCalledTimes(1) + expect(modelScan.mock.invocationCallOrder[0]).toBeLessThan( + mediaScan.mock.invocationCallOrder[0] + ) + }) + + it('does not surface added-node missing media when upload state is marked between deferred scans', async () => { + const graph = new LGraph() + vi.spyOn(app, 'rootGraph', 'get').mockReturnValue(graph) + vi.spyOn(missingModelScan, 'scanNodeModelCandidates').mockReturnValue([]) + const mediaScan = vi.spyOn(missingMediaScan, 'scanNodeMediaCandidates') + installErrorClearingHooks(graph) + + const node = new LGraphNode('LoadVideo') + node.type = 'LoadVideo' + node.addWidget('combo', 'file', 'uploading.mp4', () => undefined, { + values: [] + }) + + graph.add(node) + await Promise.resolve() + node.isUploading = true + await Promise.resolve() + + expect(useMissingMediaStore().missingMediaCandidates).toBeNull() + expect(mediaScan).toHaveBeenCalledOnce() + }) }) describe('onNodeRemoved clears missing asset errors by execution ID', () => { @@ -611,7 +700,6 @@ describe('realtime scan verifies pending cloud candidates', () => { describe('realtime verification staleness guards', () => { beforeEach(() => { - vi.restoreAllMocks() setActivePinia(createTestingPinia({ stubActions: false })) vi.spyOn(app, 'isGraphReady', 'get').mockReturnValue(false) }) @@ -771,7 +859,6 @@ describe('realtime verification staleness guards', () => { describe('scan skips interior of bypassed subgraph containers', () => { beforeEach(() => { - vi.restoreAllMocks() setActivePinia(createTestingPinia({ stubActions: false })) vi.spyOn(app, 'isGraphReady', 'get').mockReturnValue(false) }) diff --git a/src/composables/graph/useErrorClearingHooks.ts b/src/composables/graph/useErrorClearingHooks.ts index 5fcd9dd129..2568643bd1 100644 --- a/src/composables/graph/useErrorClearingHooks.ts +++ b/src/composables/graph/useErrorClearingHooks.ts @@ -155,25 +155,26 @@ function isNodeInactive(mode: number): boolean { return mode === LGraphEventMode.NEVER || mode === LGraphEventMode.BYPASS } -/** Scan a single node and add confirmed missing model/media to stores. - * For subgraph containers, also scans all active interior nodes. */ -function scanAndAddNodeErrors(node: LGraphNode): void { +function scanNodeErrorTargets( + node: LGraphNode, + scanNode: (node: LGraphNode) => void +): void { if (!app.rootGraph) return if (node.isSubgraphNode?.() && node.subgraph) { for (const innerNode of collectAllNodes(node.subgraph)) { if (innerNode.isSubgraphNode?.()) continue if (isNodeInactive(innerNode.mode)) continue - scanSingleNodeErrors(innerNode) + scanNode(innerNode) } return } - scanSingleNodeErrors(node) + scanNode(node) } -function scanSingleNodeErrors(node: LGraphNode): void { - if (!app.rootGraph) return +function getActiveExecutionId(node: LGraphNode): string | null { + if (!app.rootGraph) return null // Skip when any enclosing subgraph is muted/bypassed. Callers only // verify each node's own mode; entering a bypassed subgraph (via // useGraphNodeManager replaying onNodeAdded for existing interior @@ -181,7 +182,25 @@ function scanSingleNodeErrors(node: LGraphNode): void { // execId means the node has no current graph (e.g. detached mid // lifecycle) — also skip, since we cannot verify its scope. const execId = getExecutionIdByNode(app.rootGraph, node) - if (!execId || !isAncestorPathActive(app.rootGraph, execId)) return + if (!execId || !isAncestorPathActive(app.rootGraph, execId)) return null + return execId +} + +/** Scan a single node and add confirmed missing model/media to stores. + * For subgraph containers, also scans all active interior nodes. */ +function scanAndAddNodeErrors(node: LGraphNode): void { + scanNodeErrorTargets(node, scanSingleNodeErrors) +} + +function scanSingleNodeErrors(node: LGraphNode): void { + scanSingleNodeModelsAndTypes(node) + scanSingleNodeMedia(node) +} + +function scanSingleNodeModelsAndTypes(node: LGraphNode): void { + if (!app.rootGraph) return + const execId = getActiveExecutionId(node) + if (!execId) return const modelCandidates = scanNodeModelCandidates( app.rootGraph, @@ -204,39 +223,40 @@ function scanSingleNodeErrors(node: LGraphNode): void { void verifyAndAddPendingModels(pendingModels) } + const originalType = node.last_serialization?.type ?? node.type ?? 'Unknown' + if (!(originalType in LiteGraph.registered_node_types)) { + const nodeReplacementStore = useNodeReplacementStore() + const replacement = nodeReplacementStore.getReplacementFor(originalType) + const store = useMissingNodesErrorStore() + const existing = store.missingNodesError?.nodeTypes ?? [] + store.surfaceMissingNodes([ + ...existing, + { + type: originalType, + nodeId: execId, + cnrId: getCnrIdFromNode(node), + isReplaceable: replacement !== null, + replacement: replacement ?? undefined + } + ]) + } +} + +function scanSingleNodeMedia(node: LGraphNode): void { + if (!app.rootGraph) return + if (!getActiveExecutionId(node)) return + const mediaCandidates = scanNodeMediaCandidates(app.rootGraph, node, isCloud) const confirmedMedia = mediaCandidates.filter((c) => c.isMissing === true) if (confirmedMedia.length) { useMissingMediaStore().addMissingMedia(confirmedMedia) } // Cloud media scans return pending for asset verification. OSS scans only - // return pending for generated output/temp media. + // return pending for generated output media. const pendingMedia = mediaCandidates.filter((c) => c.isMissing === undefined) if (pendingMedia.length) { void verifyAndAddPendingMedia(pendingMedia) } - - // Check for missing node type - const originalType = node.last_serialization?.type ?? node.type ?? 'Unknown' - if (!(originalType in LiteGraph.registered_node_types)) { - const execId = getExecutionIdByNode(app.rootGraph, node) - if (execId) { - const nodeReplacementStore = useNodeReplacementStore() - const replacement = nodeReplacementStore.getReplacementFor(originalType) - const store = useMissingNodesErrorStore() - const existing = store.missingNodesError?.nodeTypes ?? [] - store.surfaceMissingNodes([ - ...existing, - { - type: originalType, - nodeId: execId, - cnrId: getCnrIdFromNode(node), - isReplaceable: replacement !== null, - replacement: replacement ?? undefined - } - ]) - } - } } /** @@ -293,10 +313,23 @@ async function verifyAndAddPendingMedia( } } -function scanAddedNode(node: LGraphNode): void { +function scanAddedNode( + node: LGraphNode, + scanNode: (node: LGraphNode) => void +): void { if (!app.rootGraph || ChangeTracker.isLoadingGraph) return if (isNodeInactive(node.mode)) return - scanAndAddNodeErrors(node) + scanNodeErrorTargets(node, scanNode) +} + +function scheduleAddedNodeScan(node: LGraphNode): void { + queueMicrotask(() => { + scanAddedNode(node, scanSingleNodeModelsAndTypes) + // Paste/drop upload handlers run immediately after graph.add and must set + // node.isUploading synchronously before their first await. This second + // microtask lets that upload state settle before media widgets are scanned. + queueMicrotask(() => scanAddedNode(node, scanSingleNodeMedia)) + }) } function handleNodeModeChange( @@ -368,10 +401,12 @@ export function installErrorClearingHooks(graph: LGraph): () => void { // Scan pasted/duplicated nodes for missing models/media. // Skip during loadGraphData (undo/redo/tab switch) — those are // handled by the full pipeline or cache restore. - // Deferred to microtask because onNodeAdded fires before - // node.configure() restores widget values. + // Model and node scans use the original one-microtask deferral so pasted + // missing-model errors appear before selection-scoped tabs recalculate. + // Media gets one extra microtask so drag/drop upload handlers can mark + // transient upload state before media detection reads the widget value. if (!ChangeTracker.isLoadingGraph) { - queueMicrotask(() => scanAddedNode(node)) + scheduleAddedNodeScan(node) } originalOnNodeAdded?.call(this, node) diff --git a/src/composables/node/useNodeImageUpload.test.ts b/src/composables/node/useNodeImageUpload.test.ts index b03d80237b..1e088bfdce 100644 --- a/src/composables/node/useNodeImageUpload.test.ts +++ b/src/composables/node/useNodeImageUpload.test.ts @@ -53,8 +53,8 @@ function createMockNode(): LGraphNode { }) } -function createFile(name = 'test.png'): File { - return new File(['data'], name, { type: 'image/png' }) +function createFile(name = 'test.png', type = 'image/png'): File { + return new File(['data'], name, { type }) } function successResponse(name: string, subfolder?: string) { @@ -94,15 +94,21 @@ describe('useNodeImageUpload', () => { }) }) - it('sets isUploading true during upload and false after', async () => { - mockFetchApi.mockResolvedValueOnce(successResponse('test.png')) + it.for([ + { mediaType: 'image', filename: 'test.png', mimeType: 'image/png' }, + { mediaType: 'video', filename: 'clip.mp4', mimeType: 'video/mp4' } + ])( + 'sets isUploading true during $mediaType upload and false after', + async ({ filename, mimeType }) => { + mockFetchApi.mockResolvedValueOnce(successResponse(filename)) - const promise = capturedDragOnDrop([createFile()]) - expect(node.isUploading).toBe(true) + const promise = capturedDragOnDrop([createFile(filename, mimeType)]) + expect(node.isUploading).toBe(true) - await promise - expect(node.isUploading).toBe(false) - }) + await promise + expect(node.isUploading).toBe(false) + } + ) it('clears node.imgs on upload start', async () => { mockFetchApi.mockResolvedValueOnce(successResponse('test.png')) diff --git a/src/extensions/core/uploadAudio.test.ts b/src/extensions/core/uploadAudio.test.ts new file mode 100644 index 0000000000..d123ae1ea3 --- /dev/null +++ b/src/extensions/core/uploadAudio.test.ts @@ -0,0 +1,241 @@ +import { fromAny } from '@total-typescript/shoehorn' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import type { LGraphNode } from '@/lib/litegraph/src/litegraph' +import type { ComfyExtension } from '@/types/comfy' + +const { mockAddAlert, mockApiURL, mockFetchApi, mockRegisterExtension } = + vi.hoisted(() => ({ + mockAddAlert: vi.fn(), + mockApiURL: vi.fn((url: string) => `api:${url}`), + mockFetchApi: vi.fn(), + mockRegisterExtension: vi.fn() + })) + +let capturedDragDrop: ((files: File[]) => Promise) | undefined +let capturedFileSelect: + | ((files: File[]) => Promise) + | undefined +let capturedPaste: ((files: File[]) => Promise) | undefined + +type AudioUploadWidget = (node: LGraphNode, inputName: string) => unknown + +vi.mock('extendable-media-recorder', () => ({ + MediaRecorder: class MockMediaRecorder {} +})) + +vi.mock('@/composables/node/useNodeDragAndDrop', () => ({ + useNodeDragAndDrop: ( + _node: LGraphNode, + options: { onDrop: typeof capturedDragDrop } + ) => { + capturedDragDrop = options.onDrop + } +})) + +vi.mock('@/composables/node/useNodeFileInput', () => ({ + useNodeFileInput: ( + _node: LGraphNode, + options: { onSelect: typeof capturedFileSelect } + ) => { + capturedFileSelect = options.onSelect + return { openFileSelection: vi.fn() } + } +})) + +vi.mock('@/composables/node/useNodePaste', () => ({ + useNodePaste: ( + _node: LGraphNode, + options: { onPaste: typeof capturedPaste } + ) => { + capturedPaste = options.onPaste + } +})) + +vi.mock('@/i18n', () => ({ + t: (key: string) => key +})) + +vi.mock('@/platform/updates/common/toastStore', () => ({ + useToastStore: () => ({ addAlert: mockAddAlert }) +})) + +vi.mock('@/renderer/extensions/vueNodes/widgets/utils/audioUtils', () => ({ + getResourceURL: (subfolder = '', filename = '', type = 'input') => + `/view?filename=${filename}&subfolder=${subfolder}&type=${type}`, + splitFilePath: (path: string) => ['', path, 'input'] +})) + +vi.mock('@/scripts/api', () => ({ + api: { + apiURL: mockApiURL, + fetchApi: mockFetchApi + } +})) + +vi.mock('@/scripts/app', () => ({ + app: { + registerExtension: mockRegisterExtension, + rootGraph: { id: 'root' } + } +})) + +vi.mock('@/stores/widgetValueStore', () => ({ + useWidgetValueStore: () => ({ + getWidget: vi.fn() + }) +})) + +vi.mock('@/utils/graphTraversalUtil', () => ({ + getNodeByLocatorId: vi.fn() +})) + +vi.mock('@/services/audioService', () => ({ + useAudioService: () => ({}) +})) + +function createFile(name = 'clip.mp3'): File { + return new File(['audio'], name, { type: 'audio/mpeg' }) +} + +function successResponse(name: string, subfolder?: string) { + return { + status: 200, + json: () => Promise.resolve({ name, subfolder }) + } +} + +function failResponse(status = 500) { + return { + status, + statusText: 'Server Error' + } +} + +function createAudioNode() { + const audioWidget = { + name: 'audio', + value: 'previous.mp3', + options: { values: ['previous.mp3'] }, + callback: vi.fn() + } + const audioUIWidget = { + name: 'audioUI', + element: document.createElement('audio'), + value: '', + callback: vi.fn() + } + const uploadWidget = { label: '', serialize: true, canvasOnly: false } + const node = fromAny({ + widgets: [audioWidget, audioUIWidget], + isUploading: false, + graph: { setDirtyCanvas: vi.fn() }, + addWidget: vi.fn(() => uploadWidget) + }) + + return { audioUIWidget, audioWidget, node, uploadWidget } +} + +async function loadAudioUploadWidget() { + vi.resetModules() + mockRegisterExtension.mockClear() + await import('./uploadAudio') + const extension = mockRegisterExtension.mock.calls + .map(([extension]) => extension as ComfyExtension) + .find((extension) => extension.name === 'Comfy.UploadAudio') + if (!extension) + throw new Error('Comfy.UploadAudio extension was not registered') + const widgets = await extension.getCustomWidgets!(fromAny({})) + return (widgets as Record).AUDIOUPLOAD +} + +describe('Comfy.UploadAudio AUDIOUPLOAD widget', () => { + beforeEach(() => { + vi.clearAllMocks() + capturedDragDrop = undefined + capturedFileSelect = undefined + capturedPaste = undefined + }) + + it('sets isUploading while upload is in progress and clears it after success', async () => { + const AUDIOUPLOAD = await loadAudioUploadWidget() + const { audioWidget, node } = createAudioNode() + AUDIOUPLOAD(node, 'upload') + + let resolveUpload: (response: ReturnType) => void + mockFetchApi.mockReturnValueOnce( + new Promise((resolve) => { + resolveUpload = resolve + }) + ) + + const upload = capturedDragDrop!([createFile()]) + + expect(node.isUploading).toBe(true) + expect(audioWidget.value).toBe('clip.mp3') + + resolveUpload!(successResponse('uploaded.mp3', 'pasted')) + await upload + + expect(node.isUploading).toBe(false) + expect(audioWidget.value).toBe('pasted/uploaded.mp3') + expect(audioWidget.options.values).toContain('pasted/uploaded.mp3') + expect(node.graph?.setDirtyCanvas).toHaveBeenCalledWith(true) + }) + + it('rejects concurrent audio uploads without starting another request', async () => { + const AUDIOUPLOAD = await loadAudioUploadWidget() + const { node } = createAudioNode() + AUDIOUPLOAD(node, 'upload') + node.isUploading = true + + const result = await capturedDragDrop!([createFile()]) + + expect(result).toEqual([]) + expect(mockAddAlert).toHaveBeenCalledWith('g.uploadAlreadyInProgress') + expect(mockFetchApi).not.toHaveBeenCalled() + }) + + it('rolls back the widget value and clears isUploading when upload fails', async () => { + const AUDIOUPLOAD = await loadAudioUploadWidget() + const { audioWidget, node } = createAudioNode() + AUDIOUPLOAD(node, 'upload') + mockFetchApi.mockResolvedValueOnce(failResponse()) + + await capturedPaste!([createFile()]) + + expect(node.isUploading).toBe(false) + expect(audioWidget.value).toBe('previous.mp3') + expect(mockAddAlert).toHaveBeenCalledWith('500 - Server Error') + expect(node.graph?.setDirtyCanvas).toHaveBeenCalledWith(true) + }) + + it('rolls back the widget value and clears isUploading when upload throws synchronously', async () => { + const AUDIOUPLOAD = await loadAudioUploadWidget() + const { audioWidget, node } = createAudioNode() + AUDIOUPLOAD(node, 'upload') + const error = new Error('Upload failed before request promise') + mockFetchApi.mockImplementationOnce(() => { + throw error + }) + + await capturedDragDrop!([createFile()]) + + expect(node.isUploading).toBe(false) + expect(audioWidget.value).toBe('previous.mp3') + expect(mockAddAlert).toHaveBeenCalledWith(error) + expect(node.graph?.setDirtyCanvas).toHaveBeenCalledWith(true) + }) + + it('returns early when no files are provided', async () => { + const AUDIOUPLOAD = await loadAudioUploadWidget() + const { node } = createAudioNode() + AUDIOUPLOAD(node, 'upload') + + const result = await capturedFileSelect!([]) + + expect(result).toEqual([]) + expect(node.isUploading).toBe(false) + expect(mockFetchApi).not.toHaveBeenCalled() + }) +}) diff --git a/src/extensions/core/uploadAudio.ts b/src/extensions/core/uploadAudio.ts index 8b6f65ab26..34229d8c60 100644 --- a/src/extensions/core/uploadAudio.ts +++ b/src/extensions/core/uploadAudio.ts @@ -234,9 +234,17 @@ app.registerExtension({ } const handleUpload = async (files: File[]) => { - if (files?.length) { - const previousValue = audioWidget.value - audioWidget.value = files[0].name + if (!files?.length) return files + + if (node.isUploading) { + useToastStore().addAlert(t('g.uploadAlreadyInProgress')) + return [] + } + + node.isUploading = true + const previousValue = audioWidget.value + audioWidget.value = files[0].name + try { const success = await uploadFile( audioWidget, audioUIWidget, @@ -246,6 +254,9 @@ app.registerExtension({ if (!success) { audioWidget.value = previousValue } + } finally { + node.isUploading = false + node.graph?.setDirtyCanvas(true) } return files } diff --git a/src/platform/missingMedia/missingMediaScan.test.ts b/src/platform/missingMedia/missingMediaScan.test.ts index 78073743bc..edb906dd6c 100644 --- a/src/platform/missingMedia/missingMediaScan.test.ts +++ b/src/platform/missingMedia/missingMediaScan.test.ts @@ -204,6 +204,48 @@ describe('scanNodeMediaCandidates', () => { expect(result).toEqual([]) }) + it.for([false, true])( + 'returns empty while a media upload is pending on the node (isCloud: %s)', + (isCloud) => { + const graph = makeGraph([]) + const node = makeMediaNode( + 1, + 'LoadVideo', + [makeMediaCombo('file', 'clip.mp4', [])], + 0 + ) + node.isUploading = true + + const result = scanNodeMediaCandidates(graph, node, isCloud) + + expect(result).toEqual([]) + } + ) + + it('detects missing media again after upload state clears', () => { + const graph = makeGraph([]) + const node = makeMediaNode( + 1, + 'LoadVideo', + [makeMediaCombo('file', 'clip.mp4', [])], + 0 + ) + + node.isUploading = true + expect(scanNodeMediaCandidates(graph, node, false)).toEqual([]) + + node.isUploading = false + expect(scanNodeMediaCandidates(graph, node, false)).toEqual([ + expect.objectContaining({ + nodeType: 'LoadVideo', + widgetName: 'file', + mediaType: 'video', + name: 'clip.mp4', + isMissing: true + }) + ]) + }) + it.each([ { nodeType: 'LoadImage', diff --git a/src/platform/missingMedia/missingMediaScan.ts b/src/platform/missingMedia/missingMediaScan.ts index afbd3bcf27..9adb179f82 100644 --- a/src/platform/missingMedia/missingMediaScan.ts +++ b/src/platform/missingMedia/missingMediaScan.ts @@ -87,6 +87,7 @@ export function scanNodeMediaCandidates( const mediaInfo = MEDIA_NODE_WIDGETS[node.type] if (!mediaInfo) return [] + if (node.isUploading) return [] const executionId = getExecutionIdByNode(rootGraph, node) if (!executionId) return []