diff --git a/src/base/common/downloadUtil.ts b/src/base/common/downloadUtil.ts index 307a3e35b..4462dd1a7 100644 --- a/src/base/common/downloadUtil.ts +++ b/src/base/common/downloadUtil.ts @@ -1,29 +1,64 @@ /** * Utility functions for downloading files */ +import { isCloud } from '@/platform/distribution/types' // Constants const DEFAULT_DOWNLOAD_FILENAME = 'download.png' +/** + * Trigger a download by creating a temporary anchor element + * @param href - The URL or blob URL to download + * @param filename - The filename to suggest to the browser + */ +function triggerLinkDownload(href: string, filename: string): void { + const link = document.createElement('a') + link.href = href + link.download = filename + link.style.display = 'none' + + document.body.appendChild(link) + link.click() + document.body.removeChild(link) +} + /** * Download a file from a URL by creating a temporary anchor element * @param url - The URL of the file to download (must be a valid URL string) * @param filename - Optional filename override (will use URL filename or default if not provided) * @throws {Error} If the URL is invalid or empty */ -export const downloadFile = (url: string, filename?: string): void => { +export function downloadFile(url: string, filename?: string): void { if (!url || typeof url !== 'string' || url.trim().length === 0) { throw new Error('Invalid URL provided for download') } - const link = document.createElement('a') - link.href = url - link.download = + + const inferredFilename = filename || extractFilenameFromUrl(url) || DEFAULT_DOWNLOAD_FILENAME - // Trigger download - document.body.appendChild(link) - link.click() - document.body.removeChild(link) + if (isCloud) { + // Assets from cross-origin (e.g., GCS) cannot be downloaded this way + void downloadViaBlobFetch(url, inferredFilename).catch((error) => { + console.error('Failed to download file', error) + }) + return + } + + triggerLinkDownload(url, inferredFilename) +} + +/** + * Download a Blob by creating a temporary object URL and anchor element + * @param filename - The filename to suggest to the browser + * @param blob - The Blob to download + */ +export function downloadBlob(filename: string, blob: Blob): void { + const url = URL.createObjectURL(blob) + + triggerLinkDownload(url, filename) + + // Revoke on the next microtask to give the browser time to start the download + queueMicrotask(() => URL.revokeObjectURL(url)) } /** @@ -39,3 +74,15 @@ const extractFilenameFromUrl = (url: string): string | null => { return null } } + +const downloadViaBlobFetch = async ( + href: string, + filename: string +): Promise => { + const response = await fetch(href) + if (!response.ok) { + throw new Error(`Failed to fetch ${href}: ${response.status}`) + } + const blob = await response.blob() + downloadBlob(filename, blob) +} diff --git a/src/extensions/core/load3d/ModelExporter.ts b/src/extensions/core/load3d/ModelExporter.ts index bf677ba44..f3ff3a443 100644 --- a/src/extensions/core/load3d/ModelExporter.ts +++ b/src/extensions/core/load3d/ModelExporter.ts @@ -3,6 +3,7 @@ import { GLTFExporter } from 'three/examples/jsm/exporters/GLTFExporter' import { OBJExporter } from 'three/examples/jsm/exporters/OBJExporter' import { STLExporter } from 'three/examples/jsm/exporters/STLExporter' +import { downloadBlob } from '@/base/common/downloadUtil' import { t } from '@/i18n' import { useToastStore } from '@/platform/updates/common/toastStore' @@ -38,13 +39,7 @@ export class ModelExporter { try { const response = await fetch(url) const blob = await response.blob() - - const link = document.createElement('a') - link.href = URL.createObjectURL(blob) - link.download = desiredFilename - link.click() - - URL.revokeObjectURL(link.href) + downloadBlob(desiredFilename, blob) } catch (error) { console.error('Error downloading from URL:', error) useToastStore().addAlert(t('toastMessages.failedToDownloadFile')) @@ -152,19 +147,11 @@ export class ModelExporter { private static saveArrayBuffer(buffer: ArrayBuffer, filename: string): void { const blob = new Blob([buffer], { type: 'application/octet-stream' }) - const link = document.createElement('a') - link.href = URL.createObjectURL(blob) - link.download = filename - link.click() - URL.revokeObjectURL(link.href) + downloadBlob(filename, blob) } private static saveString(text: string, filename: string): void { const blob = new Blob([text], { type: 'text/plain' }) - const link = document.createElement('a') - link.href = URL.createObjectURL(blob) - link.download = filename - link.click() - URL.revokeObjectURL(link.href) + downloadBlob(filename, blob) } } diff --git a/src/extensions/core/load3d/RecordingManager.ts b/src/extensions/core/load3d/RecordingManager.ts index 679fa9c5d..bf75ff2e8 100644 --- a/src/extensions/core/load3d/RecordingManager.ts +++ b/src/extensions/core/load3d/RecordingManager.ts @@ -1,5 +1,7 @@ import * as THREE from 'three' +import { downloadBlob } from '@/base/common/downloadUtil' + import { type EventManagerInterface } from './interfaces' export class RecordingManager { @@ -149,17 +151,7 @@ export class RecordingManager { try { const blob = new Blob(this.recordedChunks, { type: 'video/webm' }) - - const url = URL.createObjectURL(blob) - const a = document.createElement('a') - document.body.appendChild(a) - a.style.display = 'none' - a.href = url - a.download = filename - a.click() - - window.URL.revokeObjectURL(url) - document.body.removeChild(a) + downloadBlob(filename, blob) this.eventManager.emitEvent('recordingExported', null) } catch (error) { diff --git a/src/extensions/core/nodeTemplates.ts b/src/extensions/core/nodeTemplates.ts index 5737ba54d..528f73a5f 100644 --- a/src/extensions/core/nodeTemplates.ts +++ b/src/extensions/core/nodeTemplates.ts @@ -1,3 +1,4 @@ +import { downloadBlob } from '@/base/common/downloadUtil' import { t } from '@/i18n' import { LGraphCanvas } from '@/lib/litegraph/src/litegraph' import { useToastStore } from '@/platform/updates/common/toastStore' @@ -145,18 +146,7 @@ class ManageTemplates extends ComfyDialog { const json = JSON.stringify({ templates: this.templates }, null, 2) // convert the data to a JSON string const blob = new Blob([json], { type: 'application/json' }) - const url = URL.createObjectURL(blob) - const a = $el('a', { - href: url, - download: 'node_templates.json', - style: { display: 'none' }, - parent: document.body - }) - a.click() - setTimeout(function () { - a.remove() - window.URL.revokeObjectURL(url) - }, 0) + downloadBlob('node_templates.json', blob) } override show() { @@ -298,19 +288,9 @@ class ManageTemplates extends ComfyDialog { const blob = new Blob([json], { type: 'application/json' }) - const url = URL.createObjectURL(blob) - const a = $el('a', { - href: url, - // @ts-expect-error fixme ts strict error - download: (nameInput.value || t.name) + '.json', - style: { display: 'none' }, - parent: document.body - }) - a.click() - setTimeout(function () { - a.remove() - window.URL.revokeObjectURL(url) - }, 0) + // @ts-expect-error fixme ts strict error + const name = (nameInput.value || t.name) + '.json' + downloadBlob(name, blob) } }), $el('button', { diff --git a/src/platform/workflow/core/services/workflowService.ts b/src/platform/workflow/core/services/workflowService.ts index c4234cf45..2eba1345b 100644 --- a/src/platform/workflow/core/services/workflowService.ts +++ b/src/platform/workflow/core/services/workflowService.ts @@ -1,5 +1,6 @@ import { toRaw } from 'vue' +import { downloadBlob } from '@/base/common/downloadUtil' import { t } from '@/i18n' import { LGraph, LGraphCanvas } from '@/lib/litegraph/src/litegraph' import type { Point, SerialisableGraph } from '@/lib/litegraph/src/litegraph' @@ -13,7 +14,6 @@ import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/w import { useWorkflowThumbnail } from '@/renderer/core/thumbnail/useWorkflowThumbnail' import { app } from '@/scripts/app' import { blankGraph, defaultGraph } from '@/scripts/defaultGraph' -import { downloadBlob } from '@/scripts/utils' import { useDialogService } from '@/services/dialogService' import { useDomWidgetStore } from '@/stores/domWidgetStore' import { useWorkspaceStore } from '@/stores/workspaceStore' diff --git a/src/scripts/utils.ts b/src/scripts/utils.ts index 59556523b..a1bcc3d89 100644 --- a/src/scripts/utils.ts +++ b/src/scripts/utils.ts @@ -51,20 +51,8 @@ export async function addStylesheet( }) } -export function downloadBlob(filename: string, blob: Blob) { - const url = URL.createObjectURL(blob) - const a = $el('a', { - href: url, - download: filename, - style: { display: 'none' }, - parent: document.body - }) - a.click() - setTimeout(function () { - a.remove() - window.URL.revokeObjectURL(url) - }, 0) -} +/** @knipIgnoreUnusedButUsedByCustomNodes */ +export { downloadBlob } from '@/base/common/downloadUtil' export function uploadFile(accept: string) { return new Promise((resolve, reject) => { diff --git a/src/services/colorPaletteService.ts b/src/services/colorPaletteService.ts index bfbfeb420..ddaa6d702 100644 --- a/src/services/colorPaletteService.ts +++ b/src/services/colorPaletteService.ts @@ -1,13 +1,14 @@ import { toRaw } from 'vue' import { fromZodError } from 'zod-validation-error' +import { downloadBlob } from '@/base/common/downloadUtil' import { useErrorHandling } from '@/composables/useErrorHandling' import { LGraphCanvas, LiteGraph } from '@/lib/litegraph/src/litegraph' import { useSettingStore } from '@/platform/settings/settingStore' import { paletteSchema } from '@/schemas/colorPaletteSchema' import type { Colors, Palette } from '@/schemas/colorPaletteSchema' import { app } from '@/scripts/app' -import { downloadBlob, uploadFile } from '@/scripts/utils' +import { uploadFile } from '@/scripts/utils' import { useNodeDefStore } from '@/stores/nodeDefStore' import { useColorPaletteStore } from '@/stores/workspace/colorPaletteStore' diff --git a/src/services/litegraphService.ts b/src/services/litegraphService.ts index 6c84d7db2..9bc35e0b9 100644 --- a/src/services/litegraphService.ts +++ b/src/services/litegraphService.ts @@ -1,5 +1,6 @@ import _ from 'es-toolkit/compat' +import { downloadFile } from '@/base/common/downloadUtil' import { useSelectedLiteGraphItems } from '@/composables/canvas/useSelectedLiteGraphItems' import { useNodeAnimatedImage } from '@/composables/node/useNodeAnimatedImage' import { useNodeCanvasImagePreview } from '@/composables/node/useNodeCanvasImagePreview' @@ -756,18 +757,10 @@ export const useLitegraphService = () => { { content: 'Save Image', callback: () => { - const a = document.createElement('a') const url = new URL(img.src) url.searchParams.delete('preview') - a.href = url.toString() - a.setAttribute( - 'download', - // @ts-expect-error fixme ts strict error - new URLSearchParams(url.search).get('filename') - ) - document.body.append(a) - a.click() - requestAnimationFrame(() => a.remove()) + const filename = new URLSearchParams(url.search).get('filename') + downloadFile(url.toString(), filename ?? undefined) } } ) diff --git a/tests-ui/tests/base/common/downloadUtil.test.ts b/tests-ui/tests/base/common/downloadUtil.test.ts index 7231b3620..8684c8e1c 100644 --- a/tests-ui/tests/base/common/downloadUtil.test.ts +++ b/tests-ui/tests/base/common/downloadUtil.test.ts @@ -2,15 +2,38 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { downloadFile } from '@/base/common/downloadUtil' +let mockIsCloud = false + +vi.mock('@/platform/distribution/types', () => ({ + get isCloud() { + return mockIsCloud + } +})) + +// Global stubs +const createObjectURLSpy = vi + .spyOn(URL, 'createObjectURL') + .mockReturnValue('blob:mock-url') +const revokeObjectURLSpy = vi + .spyOn(URL, 'revokeObjectURL') + .mockImplementation(() => {}) + describe('downloadUtil', () => { let mockLink: HTMLAnchorElement + let fetchMock: ReturnType beforeEach(() => { + mockIsCloud = false + fetchMock = vi.fn() + vi.stubGlobal('fetch', fetchMock) + createObjectURLSpy.mockClear().mockReturnValue('blob:mock-url') + revokeObjectURLSpy.mockClear().mockImplementation(() => {}) // Create a mock anchor element mockLink = { href: '', download: '', - click: vi.fn() + click: vi.fn(), + style: { display: '' } } as unknown as HTMLAnchorElement // Spy on DOM methods @@ -20,7 +43,7 @@ describe('downloadUtil', () => { }) afterEach(() => { - vi.restoreAllMocks() + vi.unstubAllGlobals() }) describe('downloadFile', () => { @@ -35,6 +58,8 @@ describe('downloadUtil', () => { expect(document.body.appendChild).toHaveBeenCalledWith(mockLink) expect(mockLink.click).toHaveBeenCalled() expect(document.body.removeChild).toHaveBeenCalledWith(mockLink) + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should use custom filename when provided', () => { @@ -45,6 +70,8 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(testUrl) expect(mockLink.download).toBe(customFilename) + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should extract filename from URL query parameters', () => { @@ -55,6 +82,7 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(testUrl) expect(mockLink.download).toBe('extracted-image.jpg') + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should use default filename when URL has no filename parameter', () => { @@ -64,6 +92,7 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(testUrl) expect(mockLink.download).toBe('download.png') + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should handle invalid URLs gracefully', () => { @@ -74,6 +103,8 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(invalidUrl) expect(mockLink.download).toBe('download.png') expect(mockLink.click).toHaveBeenCalled() + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should prefer custom filename over extracted filename', () => { @@ -84,6 +115,7 @@ describe('downloadUtil', () => { downloadFile(testUrl, customFilename) expect(mockLink.download).toBe(customFilename) + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should handle URLs with empty filename parameter', () => { @@ -92,6 +124,7 @@ describe('downloadUtil', () => { downloadFile(testUrl) expect(mockLink.download).toBe('download.png') + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should handle relative URLs by using window.location.origin', () => { @@ -101,6 +134,8 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(relativeUrl) expect(mockLink.download).toBe('relative-image.png') + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should clean up DOM elements after download', () => { @@ -111,6 +146,54 @@ describe('downloadUtil', () => { // Verify the element was added and then removed expect(document.body.appendChild).toHaveBeenCalledWith(mockLink) expect(document.body.removeChild).toHaveBeenCalledWith(mockLink) + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() + }) + + it('streams downloads via blob when running in cloud', async () => { + mockIsCloud = true + const testUrl = 'https://storage.googleapis.com/bucket/file.bin' + const blob = new Blob(['test']) + const blobFn = vi.fn().mockResolvedValue(blob) + fetchMock.mockResolvedValue({ + ok: true, + status: 200, + blob: blobFn + } as unknown as Response) + + downloadFile(testUrl) + + expect(fetchMock).toHaveBeenCalledWith(testUrl) + const fetchPromise = fetchMock.mock.results[0].value as Promise + await fetchPromise + const blobPromise = blobFn.mock.results[0].value as Promise + await blobPromise + await Promise.resolve() + expect(blobFn).toHaveBeenCalled() + expect(createObjectURLSpy).toHaveBeenCalledWith(blob) + expect(revokeObjectURLSpy).toHaveBeenCalledWith('blob:mock-url') + expect(mockLink.click).toHaveBeenCalled() + }) + + it('logs an error when cloud fetch fails', async () => { + mockIsCloud = true + const testUrl = 'https://storage.googleapis.com/bucket/missing.bin' + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + fetchMock.mockResolvedValue({ + ok: false, + status: 404, + blob: vi.fn() + } as unknown as Response) + + downloadFile(testUrl) + + expect(fetchMock).toHaveBeenCalledWith(testUrl) + const fetchPromise = fetchMock.mock.results[0].value as Promise + await fetchPromise + await Promise.resolve() + expect(consoleSpy).toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() + consoleSpy.mockRestore() }) }) })