diff --git a/src/composables/usePaste.test.ts b/src/composables/usePaste.test.ts index 9842ba6648..40c0f826a0 100644 --- a/src/composables/usePaste.test.ts +++ b/src/composables/usePaste.test.ts @@ -5,12 +5,13 @@ import type { LGraphGroup, LGraphNode } from '@/lib/litegraph/src/litegraph' -import { LiteGraph } from '@/lib/litegraph/src/litegraph' import { app } from '@/scripts/app' import { createMockLGraphNode } from '@/utils/__tests__/litegraphTestUtils' -import { createNode, isImageNode } from '@/utils/litegraphUtil' +import { createNode, isAudioNode, isImageNode } from '@/utils/litegraphUtil' import { cloneDataTransfer, + pasteAudioNode, + pasteAudioNodes, pasteImageNode, pasteImageNodes, usePaste @@ -203,6 +204,102 @@ describe('pasteImageNodes', () => { }) }) +describe('pasteAudioNode', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should create new LoadAudio node when no audio node provided', async () => { + const mockNode = createMockNode() + vi.mocked(createNode).mockResolvedValue(mockNode) + + const file = createAudioFile() + const dataTransfer = createDataTransfer([file]) + + await pasteAudioNode(mockCanvas, dataTransfer.items) + + expect(createNode).toHaveBeenCalledWith(mockCanvas, 'LoadAudio') + expect(mockNode.pasteFile).toHaveBeenCalledWith(file) + }) + + it('should use existing audio node when provided', async () => { + const mockNode = createMockNode() + const file = createAudioFile() + const dataTransfer = createDataTransfer([file]) + + await pasteAudioNode(mockCanvas, dataTransfer.items, mockNode) + + expect(createNode).not.toHaveBeenCalled() + expect(mockNode.pasteFile).toHaveBeenCalledWith(file) + }) + + it('should filter non-audio items', async () => { + const mockNode = createMockNode() + const audioFile = createAudioFile() + const textFile = new File([''], 'test.txt', { type: 'text/plain' }) + const dataTransfer = createDataTransfer([textFile, audioFile]) + + await pasteAudioNode(mockCanvas, dataTransfer.items, mockNode) + + expect(mockNode.pasteFile).toHaveBeenCalledWith(audioFile) + expect(mockNode.pasteFiles).toHaveBeenCalledWith([audioFile]) + }) + + it('should do nothing when no audio files present', async () => { + const mockNode = createMockNode() + const dataTransfer = createDataTransfer() + + await pasteAudioNode(mockCanvas, dataTransfer.items, mockNode) + + expect(mockNode.pasteFile).not.toHaveBeenCalled() + expect(mockNode.pasteFiles).not.toHaveBeenCalled() + }) +}) + +describe('pasteAudioNodes', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should create multiple nodes for multiple audio files', async () => { + const mockNode1 = createMockNode() + const mockNode2 = createMockNode() + vi.mocked(createNode) + .mockResolvedValueOnce(mockNode1) + .mockResolvedValueOnce(mockNode2) + + const file1 = createAudioFile('file1.mp3') + const file2 = createAudioFile('file2.wav', 'audio/wav') + + const result = await pasteAudioNodes(mockCanvas, [file1, file2]) + + expect(createNode).toHaveBeenCalledTimes(2) + expect(createNode).toHaveBeenNthCalledWith(1, mockCanvas, 'LoadAudio') + expect(createNode).toHaveBeenNthCalledWith(2, mockCanvas, 'LoadAudio') + expect(mockNode1.pasteFile).toHaveBeenCalledWith(file1) + expect(mockNode2.pasteFile).toHaveBeenCalledWith(file2) + expect(result).toEqual([mockNode1, mockNode2]) + }) + + it('should handle empty file list', async () => { + const result = await pasteAudioNodes(mockCanvas, []) + + expect(createNode).not.toHaveBeenCalled() + expect(result).toEqual([]) + }) + + it('should handle single audio file', async () => { + const mockNode = createMockNode() + vi.mocked(createNode).mockResolvedValue(mockNode) + + const file = createAudioFile() + const result = await pasteAudioNodes(mockCanvas, [file]) + + expect(createNode).toHaveBeenCalledTimes(1) + expect(result).toEqual([mockNode]) + }) +}) + describe('usePaste', () => { beforeEach(() => { vi.clearAllMocks() @@ -230,9 +327,9 @@ describe('usePaste', () => { }) }) - it('should handle audio paste', async () => { + it('should handle audio paste using createNode helper', async () => { const mockNode = createMockNode() - vi.mocked(LiteGraph.createNode).mockReturnValue(mockNode) + vi.mocked(createNode).mockResolvedValue(mockNode) usePaste() @@ -242,7 +339,29 @@ describe('usePaste', () => { document.dispatchEvent(event) await vi.waitFor(() => { - expect(LiteGraph.createNode).toHaveBeenCalledWith('LoadAudio') + expect(createNode).toHaveBeenCalledWith(mockCanvas, 'LoadAudio') + expect(mockNode.pasteFile).toHaveBeenCalledWith(file) + }) + }) + + it('should paste audio onto selected LoadAudio node', async () => { + const mockNode = createMockLGraphNode({ + is_selected: true, + pasteFile: vi.fn(), + pasteFiles: vi.fn() + }) + mockCanvas.current_node = mockNode + vi.mocked(isAudioNode).mockReturnValue(true) + + usePaste() + + const file = createAudioFile() + const dataTransfer = createDataTransfer([file]) + const event = new ClipboardEvent('paste', { clipboardData: dataTransfer }) + document.dispatchEvent(event) + + await vi.waitFor(() => { + expect(createNode).not.toHaveBeenCalled() expect(mockNode.pasteFile).toHaveBeenCalledWith(file) }) }) @@ -273,7 +392,7 @@ describe('usePaste', () => { const event = new ClipboardEvent('paste', { clipboardData: dataTransfer }) document.dispatchEvent(event) - expect(LiteGraph.createNode).not.toHaveBeenCalled() + expect(createNode).not.toHaveBeenCalled() }) it('should use existing image node when selected', () => { diff --git a/src/composables/usePaste.ts b/src/composables/usePaste.ts index ccf478e246..1ca7521d8d 100644 --- a/src/composables/usePaste.ts +++ b/src/composables/usePaste.ts @@ -1,7 +1,6 @@ import { useEventListener } from '@vueuse/core' import type { LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph' -import { LiteGraph } from '@/lib/litegraph/src/litegraph' import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema' import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' import { app } from '@/scripts/app' @@ -113,6 +112,37 @@ export async function pasteImageNodes( return nodes } +export async function pasteAudioNode( + canvas: LGraphCanvas, + items: DataTransferItemList, + audioNode: LGraphNode | null = null +): Promise { + if (!audioNode) { + audioNode = await createNode(canvas, 'LoadAudio') + } + pasteItemsOnNode(items, audioNode, 'audio') + return audioNode +} + +export async function pasteAudioNodes( + canvas: LGraphCanvas, + fileList: File[] +): Promise { + const nodes: LGraphNode[] = [] + + for (const file of fileList) { + const transfer = new DataTransfer() + transfer.items.add(file) + const node = await pasteAudioNode(canvas, transfer.items) + + if (node) { + nodes.push(node) + } + } + + return nodes +} + /** * Adds a handler on paste that extracts and loads images or workflows from pasted JSON data */ @@ -132,7 +162,6 @@ export const usePaste = () => { const { canvas } = canvasStore if (!canvas) return - const { graph } = canvas let data: DataTransfer | string | null = e.clipboardData if (!data) throw new Error('No clipboard data on clipboard event') data = cloneDataTransfer(data) @@ -146,7 +175,9 @@ export const usePaste = () => { const isVideoNodeSelected = isNodeSelected && isVideoNode(currentNode) const isAudioNodeSelected = isNodeSelected && isAudioNode(currentNode) - let audioNode: LGraphNode | null = isAudioNodeSelected ? currentNode : null + const audioNode: LGraphNode | null = isAudioNodeSelected + ? currentNode + : null const imageNode: LGraphNode | null = isImageNodeSelected ? currentNode : null @@ -168,16 +199,7 @@ export const usePaste = () => { return } } else if (item.type.startsWith('audio/')) { - if (!audioNode) { - // No audio node selected: add a new one - const newNode = LiteGraph.createNode('LoadAudio') - if (newNode) { - newNode.pos = [canvas.graph_mouse[0], canvas.graph_mouse[1]] - audioNode = graph?.add(newNode) ?? null - } - graph?.change() - } - pasteItemsOnNode(items, audioNode, 'audio') + await pasteAudioNode(canvas as LGraphCanvas, items, audioNode) return } } diff --git a/src/scripts/app.test.ts b/src/scripts/app.test.ts index 7fc8ae3193..ad090269e7 100644 --- a/src/scripts/app.test.ts +++ b/src/scripts/app.test.ts @@ -100,21 +100,24 @@ describe('ComfyApp', () => { expect(mockNode2.connect).toHaveBeenCalledWith(0, mockBatchNode, 1) }) - it('should not proceed if batch node creation fails', async () => { + it('should select single image node without batch node', async () => { const mockNode1 = createMockNode({ id: 1 }) vi.mocked(pasteImageNodes).mockResolvedValue([mockNode1]) - vi.mocked(createNode).mockResolvedValue(null) const file = createTestFile('test.png', 'image/png') await app.handleFileList([file]) - expect(mockCanvas.selectItems).not.toHaveBeenCalled() + expect(createNode).not.toHaveBeenCalled() + expect(mockCanvas.selectItems).toHaveBeenCalledWith([mockNode1]) expect(mockNode1.connect).not.toHaveBeenCalled() }) it('should handle empty file list', async () => { - await expect(app.handleFileList([])).rejects.toThrow() + await app.handleFileList([]) + + expect(pasteImageNodes).not.toHaveBeenCalled() + expect(createNode).not.toHaveBeenCalled() }) it('should not process unsupported file types', async () => { diff --git a/src/scripts/app.ts b/src/scripts/app.ts index f2ac4c9695..476b409207 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -111,9 +111,18 @@ import { ComfyAppMenu } from './ui/menu/index' import { clone } from './utils' import { type ComfyWidgetConstructor } from './widgets' import { ensureCorrectLayoutScale } from '@/renderer/extensions/vueNodes/layout/ensureCorrectLayoutScale' -import { extractFilesFromDragEvent, hasImageType } from '@/utils/eventUtils' +import { + extractFilesFromDragEvent, + hasAudioType, + hasImageType +} from '@/utils/eventUtils' import { getWorkflowDataFromFile } from '@/scripts/metadata/parser' -import { pasteImageNode, pasteImageNodes } from '@/composables/usePaste' +import { + pasteAudioNode, + pasteAudioNodes, + pasteImageNode, + pasteImageNodes +} from '@/composables/usePaste' export const ANIM_PREVIEW_WIDGET = '$$comfy_animation_preview' @@ -560,8 +569,24 @@ export class ComfyApp { const workspace = useWorkspaceStore() try { workspace.spinner = true - if (files.length > 1 && files.every(hasImageType)) { - await this.handleFileList(files) + const imageFiles = files.filter(hasImageType) + const audioFiles = files.filter(hasAudioType) + const totalMedia = imageFiles.length + audioFiles.length + const hasMultipleMedia = totalMedia > 1 + + if (hasMultipleMedia) { + if (imageFiles.length > 0) { + await this.handleFileList(imageFiles) + } + if (audioFiles.length > 0) { + await this.handleAudioFileList(audioFiles) + } + const handled = new Set([...imageFiles, ...audioFiles]) + for (const file of files.filter((f) => !handled.has(f))) { + await this.handleFile(file, 'file_drop', { + deferWarnings: true + }) + } } else { for (const file of files) { await this.handleFile(file, 'file_drop', { @@ -1562,6 +1587,12 @@ export class ComfyApp { const imageNode = await createNode(this.canvas, 'LoadImage') await pasteImageNode(this.canvas, transfer.items, imageNode) return + } else if (file.type.startsWith('audio')) { + const transfer = new DataTransfer() + transfer.items.add(file) + const audioNode = await createNode(this.canvas, 'LoadAudio') + await pasteAudioNode(this.canvas, transfer.items, audioNode) + return } this.showErrorOnFileLoad(file) @@ -1643,25 +1674,55 @@ export class ComfyApp { * @param {FileList} fileList */ async handleFileList(fileList: File[]) { - if (fileList[0].type.startsWith('image')) { - const imageNodes = await pasteImageNodes(this.canvas, fileList) + if (fileList.length === 0) return + if (!fileList[0].type.startsWith('image')) return + + const imageNodes = await pasteImageNodes(this.canvas, fileList) + if (imageNodes.length === 0) return + + if (imageNodes.length > 1) { const batchImagesNode = await createNode(this.canvas, 'BatchImagesNode') if (!batchImagesNode) return this.positionBatchNodes(imageNodes, batchImagesNode) this.canvas.selectItems([...imageNodes, batchImagesNode]) - Array.from(imageNodes).forEach((imageNode, index) => { + imageNodes.forEach((imageNode, index) => { imageNode.connect(0, batchImagesNode, index) }) + } else { + this.canvas.selectItems(imageNodes) } } + async handleAudioFileList(fileList: File[]) { + const audioNodes = await pasteAudioNodes(this.canvas, fileList) + if (audioNodes.length === 0) return + + this.positionNodes(audioNodes) + this.canvas.selectItems(audioNodes) + } + /** * Positions batched nodes in drag and drop * @param nodes * @param batchNode */ + positionNodes(nodes: LGraphNode[]): void { + if (nodes.length <= 1) return + + const [x, y] = nodes[0].getBounding() + const nodeHeight = 150 + + nodes.forEach((node, index) => { + if (index > 0) { + node.pos = [x, y + nodeHeight * index + 25 * (index + 1)] + } + }) + + this.canvas.graph?.change() + } + positionBatchNodes(nodes: LGraphNode[], batchNode: LGraphNode): void { const [x, y, width] = nodes[0].getBounding() batchNode.pos = [x + width + 100, y + 30] diff --git a/src/utils/eventUtils.test.ts b/src/utils/eventUtils.test.ts new file mode 100644 index 0000000000..edfa28e570 --- /dev/null +++ b/src/utils/eventUtils.test.ts @@ -0,0 +1,26 @@ +import { describe, expect, it } from 'vitest' +import { hasAudioType, hasImageType } from './eventUtils' + +describe('hasImageType', () => { + it('should return true for image types', () => { + expect(hasImageType({ type: 'image/png' } as File)).toBe(true) + expect(hasImageType({ type: 'image/jpeg' } as File)).toBe(true) + }) + + it('should return false for non-image types', () => { + expect(hasImageType({ type: 'audio/mpeg' } as File)).toBe(false) + expect(hasImageType({ type: 'video/mp4' } as File)).toBe(false) + }) +}) + +describe('hasAudioType', () => { + it('should return true for audio types', () => { + expect(hasAudioType({ type: 'audio/mpeg' } as File)).toBe(true) + expect(hasAudioType({ type: 'audio/wav' } as File)).toBe(true) + }) + + it('should return false for non-audio types', () => { + expect(hasAudioType({ type: 'image/png' } as File)).toBe(false) + expect(hasAudioType({ type: 'video/mp4' } as File)).toBe(false) + }) +}) diff --git a/src/utils/eventUtils.ts b/src/utils/eventUtils.ts index 25382865cf..67f8c8e90f 100644 --- a/src/utils/eventUtils.ts +++ b/src/utils/eventUtils.ts @@ -28,3 +28,7 @@ export async function extractFilesFromDragEvent( export function hasImageType({ type }: File): boolean { return type.startsWith('image') } + +export function hasAudioType({ type }: File): boolean { + return type.startsWith('audio') +}