From 4e5eba6c54404c9105ec0c12a844729d7b90c36b Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 23 Oct 2025 12:08:30 -0700 Subject: [PATCH] refactor: centralize all download utils across app and apply special cloud-specific behavior (#6188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Centralized all download functionalities across app. Then changed downloadFile on the cloud distribution to stream assets via blob fetches while desktop/local retains direct anchor downloads. This fixes issue where trying to download cross-origin resources opens them in the window, potentially losing the user's unsaved changes. ## Changes - **What**: Moved `downloadBlob` into `downloadUtil`, routed all callers (3D exporter, recording manager, node template export, workflow/palette export, Litegraph save, ~~`useDownload` consumers~~) through shared helpers, and changed `downloadFile` to `fetch` first when `isCloud` so cross-origin URLs download reliably - `useDownload` is the exception since we simply cannot do model downloads through blob (forcing user to transfer the entire model data twice is bad). Fortunately on cloud, the user doesn't need to download models locally anyway. ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-6188-refactor-centralize-all-download-utils-across-app-and-apply-special-cloud-specific-behav-2946d73d365081de9f27f0994950511d) by [Unito](https://www.unito.io) --- src/base/common/downloadUtil.ts | 63 ++++++++++++-- src/extensions/core/load3d/ModelExporter.ts | 21 +---- .../core/load3d/RecordingManager.ts | 14 +-- src/extensions/core/nodeTemplates.ts | 30 ++----- .../workflow/core/services/workflowService.ts | 2 +- src/scripts/utils.ts | 16 +--- src/services/colorPaletteService.ts | 3 +- src/services/litegraphService.ts | 13 +-- .../tests/base/common/downloadUtil.test.ts | 87 ++++++++++++++++++- 9 files changed, 160 insertions(+), 89 deletions(-) diff --git a/src/base/common/downloadUtil.ts b/src/base/common/downloadUtil.ts index 307a3e35b..4462dd1a7 100644 --- a/src/base/common/downloadUtil.ts +++ b/src/base/common/downloadUtil.ts @@ -1,29 +1,64 @@ /** * Utility functions for downloading files */ +import { isCloud } from '@/platform/distribution/types' // Constants const DEFAULT_DOWNLOAD_FILENAME = 'download.png' +/** + * Trigger a download by creating a temporary anchor element + * @param href - The URL or blob URL to download + * @param filename - The filename to suggest to the browser + */ +function triggerLinkDownload(href: string, filename: string): void { + const link = document.createElement('a') + link.href = href + link.download = filename + link.style.display = 'none' + + document.body.appendChild(link) + link.click() + document.body.removeChild(link) +} + /** * Download a file from a URL by creating a temporary anchor element * @param url - The URL of the file to download (must be a valid URL string) * @param filename - Optional filename override (will use URL filename or default if not provided) * @throws {Error} If the URL is invalid or empty */ -export const downloadFile = (url: string, filename?: string): void => { +export function downloadFile(url: string, filename?: string): void { if (!url || typeof url !== 'string' || url.trim().length === 0) { throw new Error('Invalid URL provided for download') } - const link = document.createElement('a') - link.href = url - link.download = + + const inferredFilename = filename || extractFilenameFromUrl(url) || DEFAULT_DOWNLOAD_FILENAME - // Trigger download - document.body.appendChild(link) - link.click() - document.body.removeChild(link) + if (isCloud) { + // Assets from cross-origin (e.g., GCS) cannot be downloaded this way + void downloadViaBlobFetch(url, inferredFilename).catch((error) => { + console.error('Failed to download file', error) + }) + return + } + + triggerLinkDownload(url, inferredFilename) +} + +/** + * Download a Blob by creating a temporary object URL and anchor element + * @param filename - The filename to suggest to the browser + * @param blob - The Blob to download + */ +export function downloadBlob(filename: string, blob: Blob): void { + const url = URL.createObjectURL(blob) + + triggerLinkDownload(url, filename) + + // Revoke on the next microtask to give the browser time to start the download + queueMicrotask(() => URL.revokeObjectURL(url)) } /** @@ -39,3 +74,15 @@ const extractFilenameFromUrl = (url: string): string | null => { return null } } + +const downloadViaBlobFetch = async ( + href: string, + filename: string +): Promise => { + const response = await fetch(href) + if (!response.ok) { + throw new Error(`Failed to fetch ${href}: ${response.status}`) + } + const blob = await response.blob() + downloadBlob(filename, blob) +} diff --git a/src/extensions/core/load3d/ModelExporter.ts b/src/extensions/core/load3d/ModelExporter.ts index bf677ba44..f3ff3a443 100644 --- a/src/extensions/core/load3d/ModelExporter.ts +++ b/src/extensions/core/load3d/ModelExporter.ts @@ -3,6 +3,7 @@ import { GLTFExporter } from 'three/examples/jsm/exporters/GLTFExporter' import { OBJExporter } from 'three/examples/jsm/exporters/OBJExporter' import { STLExporter } from 'three/examples/jsm/exporters/STLExporter' +import { downloadBlob } from '@/base/common/downloadUtil' import { t } from '@/i18n' import { useToastStore } from '@/platform/updates/common/toastStore' @@ -38,13 +39,7 @@ export class ModelExporter { try { const response = await fetch(url) const blob = await response.blob() - - const link = document.createElement('a') - link.href = URL.createObjectURL(blob) - link.download = desiredFilename - link.click() - - URL.revokeObjectURL(link.href) + downloadBlob(desiredFilename, blob) } catch (error) { console.error('Error downloading from URL:', error) useToastStore().addAlert(t('toastMessages.failedToDownloadFile')) @@ -152,19 +147,11 @@ export class ModelExporter { private static saveArrayBuffer(buffer: ArrayBuffer, filename: string): void { const blob = new Blob([buffer], { type: 'application/octet-stream' }) - const link = document.createElement('a') - link.href = URL.createObjectURL(blob) - link.download = filename - link.click() - URL.revokeObjectURL(link.href) + downloadBlob(filename, blob) } private static saveString(text: string, filename: string): void { const blob = new Blob([text], { type: 'text/plain' }) - const link = document.createElement('a') - link.href = URL.createObjectURL(blob) - link.download = filename - link.click() - URL.revokeObjectURL(link.href) + downloadBlob(filename, blob) } } diff --git a/src/extensions/core/load3d/RecordingManager.ts b/src/extensions/core/load3d/RecordingManager.ts index 679fa9c5d..bf75ff2e8 100644 --- a/src/extensions/core/load3d/RecordingManager.ts +++ b/src/extensions/core/load3d/RecordingManager.ts @@ -1,5 +1,7 @@ import * as THREE from 'three' +import { downloadBlob } from '@/base/common/downloadUtil' + import { type EventManagerInterface } from './interfaces' export class RecordingManager { @@ -149,17 +151,7 @@ export class RecordingManager { try { const blob = new Blob(this.recordedChunks, { type: 'video/webm' }) - - const url = URL.createObjectURL(blob) - const a = document.createElement('a') - document.body.appendChild(a) - a.style.display = 'none' - a.href = url - a.download = filename - a.click() - - window.URL.revokeObjectURL(url) - document.body.removeChild(a) + downloadBlob(filename, blob) this.eventManager.emitEvent('recordingExported', null) } catch (error) { diff --git a/src/extensions/core/nodeTemplates.ts b/src/extensions/core/nodeTemplates.ts index 5737ba54d..528f73a5f 100644 --- a/src/extensions/core/nodeTemplates.ts +++ b/src/extensions/core/nodeTemplates.ts @@ -1,3 +1,4 @@ +import { downloadBlob } from '@/base/common/downloadUtil' import { t } from '@/i18n' import { LGraphCanvas } from '@/lib/litegraph/src/litegraph' import { useToastStore } from '@/platform/updates/common/toastStore' @@ -145,18 +146,7 @@ class ManageTemplates extends ComfyDialog { const json = JSON.stringify({ templates: this.templates }, null, 2) // convert the data to a JSON string const blob = new Blob([json], { type: 'application/json' }) - const url = URL.createObjectURL(blob) - const a = $el('a', { - href: url, - download: 'node_templates.json', - style: { display: 'none' }, - parent: document.body - }) - a.click() - setTimeout(function () { - a.remove() - window.URL.revokeObjectURL(url) - }, 0) + downloadBlob('node_templates.json', blob) } override show() { @@ -298,19 +288,9 @@ class ManageTemplates extends ComfyDialog { const blob = new Blob([json], { type: 'application/json' }) - const url = URL.createObjectURL(blob) - const a = $el('a', { - href: url, - // @ts-expect-error fixme ts strict error - download: (nameInput.value || t.name) + '.json', - style: { display: 'none' }, - parent: document.body - }) - a.click() - setTimeout(function () { - a.remove() - window.URL.revokeObjectURL(url) - }, 0) + // @ts-expect-error fixme ts strict error + const name = (nameInput.value || t.name) + '.json' + downloadBlob(name, blob) } }), $el('button', { diff --git a/src/platform/workflow/core/services/workflowService.ts b/src/platform/workflow/core/services/workflowService.ts index c4234cf45..2eba1345b 100644 --- a/src/platform/workflow/core/services/workflowService.ts +++ b/src/platform/workflow/core/services/workflowService.ts @@ -1,5 +1,6 @@ import { toRaw } from 'vue' +import { downloadBlob } from '@/base/common/downloadUtil' import { t } from '@/i18n' import { LGraph, LGraphCanvas } from '@/lib/litegraph/src/litegraph' import type { Point, SerialisableGraph } from '@/lib/litegraph/src/litegraph' @@ -13,7 +14,6 @@ import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/w import { useWorkflowThumbnail } from '@/renderer/core/thumbnail/useWorkflowThumbnail' import { app } from '@/scripts/app' import { blankGraph, defaultGraph } from '@/scripts/defaultGraph' -import { downloadBlob } from '@/scripts/utils' import { useDialogService } from '@/services/dialogService' import { useDomWidgetStore } from '@/stores/domWidgetStore' import { useWorkspaceStore } from '@/stores/workspaceStore' diff --git a/src/scripts/utils.ts b/src/scripts/utils.ts index 59556523b..a1bcc3d89 100644 --- a/src/scripts/utils.ts +++ b/src/scripts/utils.ts @@ -51,20 +51,8 @@ export async function addStylesheet( }) } -export function downloadBlob(filename: string, blob: Blob) { - const url = URL.createObjectURL(blob) - const a = $el('a', { - href: url, - download: filename, - style: { display: 'none' }, - parent: document.body - }) - a.click() - setTimeout(function () { - a.remove() - window.URL.revokeObjectURL(url) - }, 0) -} +/** @knipIgnoreUnusedButUsedByCustomNodes */ +export { downloadBlob } from '@/base/common/downloadUtil' export function uploadFile(accept: string) { return new Promise((resolve, reject) => { diff --git a/src/services/colorPaletteService.ts b/src/services/colorPaletteService.ts index bfbfeb420..ddaa6d702 100644 --- a/src/services/colorPaletteService.ts +++ b/src/services/colorPaletteService.ts @@ -1,13 +1,14 @@ import { toRaw } from 'vue' import { fromZodError } from 'zod-validation-error' +import { downloadBlob } from '@/base/common/downloadUtil' import { useErrorHandling } from '@/composables/useErrorHandling' import { LGraphCanvas, LiteGraph } from '@/lib/litegraph/src/litegraph' import { useSettingStore } from '@/platform/settings/settingStore' import { paletteSchema } from '@/schemas/colorPaletteSchema' import type { Colors, Palette } from '@/schemas/colorPaletteSchema' import { app } from '@/scripts/app' -import { downloadBlob, uploadFile } from '@/scripts/utils' +import { uploadFile } from '@/scripts/utils' import { useNodeDefStore } from '@/stores/nodeDefStore' import { useColorPaletteStore } from '@/stores/workspace/colorPaletteStore' diff --git a/src/services/litegraphService.ts b/src/services/litegraphService.ts index 6c84d7db2..9bc35e0b9 100644 --- a/src/services/litegraphService.ts +++ b/src/services/litegraphService.ts @@ -1,5 +1,6 @@ import _ from 'es-toolkit/compat' +import { downloadFile } from '@/base/common/downloadUtil' import { useSelectedLiteGraphItems } from '@/composables/canvas/useSelectedLiteGraphItems' import { useNodeAnimatedImage } from '@/composables/node/useNodeAnimatedImage' import { useNodeCanvasImagePreview } from '@/composables/node/useNodeCanvasImagePreview' @@ -756,18 +757,10 @@ export const useLitegraphService = () => { { content: 'Save Image', callback: () => { - const a = document.createElement('a') const url = new URL(img.src) url.searchParams.delete('preview') - a.href = url.toString() - a.setAttribute( - 'download', - // @ts-expect-error fixme ts strict error - new URLSearchParams(url.search).get('filename') - ) - document.body.append(a) - a.click() - requestAnimationFrame(() => a.remove()) + const filename = new URLSearchParams(url.search).get('filename') + downloadFile(url.toString(), filename ?? undefined) } } ) diff --git a/tests-ui/tests/base/common/downloadUtil.test.ts b/tests-ui/tests/base/common/downloadUtil.test.ts index 7231b3620..8684c8e1c 100644 --- a/tests-ui/tests/base/common/downloadUtil.test.ts +++ b/tests-ui/tests/base/common/downloadUtil.test.ts @@ -2,15 +2,38 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { downloadFile } from '@/base/common/downloadUtil' +let mockIsCloud = false + +vi.mock('@/platform/distribution/types', () => ({ + get isCloud() { + return mockIsCloud + } +})) + +// Global stubs +const createObjectURLSpy = vi + .spyOn(URL, 'createObjectURL') + .mockReturnValue('blob:mock-url') +const revokeObjectURLSpy = vi + .spyOn(URL, 'revokeObjectURL') + .mockImplementation(() => {}) + describe('downloadUtil', () => { let mockLink: HTMLAnchorElement + let fetchMock: ReturnType beforeEach(() => { + mockIsCloud = false + fetchMock = vi.fn() + vi.stubGlobal('fetch', fetchMock) + createObjectURLSpy.mockClear().mockReturnValue('blob:mock-url') + revokeObjectURLSpy.mockClear().mockImplementation(() => {}) // Create a mock anchor element mockLink = { href: '', download: '', - click: vi.fn() + click: vi.fn(), + style: { display: '' } } as unknown as HTMLAnchorElement // Spy on DOM methods @@ -20,7 +43,7 @@ describe('downloadUtil', () => { }) afterEach(() => { - vi.restoreAllMocks() + vi.unstubAllGlobals() }) describe('downloadFile', () => { @@ -35,6 +58,8 @@ describe('downloadUtil', () => { expect(document.body.appendChild).toHaveBeenCalledWith(mockLink) expect(mockLink.click).toHaveBeenCalled() expect(document.body.removeChild).toHaveBeenCalledWith(mockLink) + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should use custom filename when provided', () => { @@ -45,6 +70,8 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(testUrl) expect(mockLink.download).toBe(customFilename) + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should extract filename from URL query parameters', () => { @@ -55,6 +82,7 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(testUrl) expect(mockLink.download).toBe('extracted-image.jpg') + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should use default filename when URL has no filename parameter', () => { @@ -64,6 +92,7 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(testUrl) expect(mockLink.download).toBe('download.png') + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should handle invalid URLs gracefully', () => { @@ -74,6 +103,8 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(invalidUrl) expect(mockLink.download).toBe('download.png') expect(mockLink.click).toHaveBeenCalled() + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should prefer custom filename over extracted filename', () => { @@ -84,6 +115,7 @@ describe('downloadUtil', () => { downloadFile(testUrl, customFilename) expect(mockLink.download).toBe(customFilename) + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should handle URLs with empty filename parameter', () => { @@ -92,6 +124,7 @@ describe('downloadUtil', () => { downloadFile(testUrl) expect(mockLink.download).toBe('download.png') + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should handle relative URLs by using window.location.origin', () => { @@ -101,6 +134,8 @@ describe('downloadUtil', () => { expect(mockLink.href).toBe(relativeUrl) expect(mockLink.download).toBe('relative-image.png') + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() }) it('should clean up DOM elements after download', () => { @@ -111,6 +146,54 @@ describe('downloadUtil', () => { // Verify the element was added and then removed expect(document.body.appendChild).toHaveBeenCalledWith(mockLink) expect(document.body.removeChild).toHaveBeenCalledWith(mockLink) + expect(fetchMock).not.toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() + }) + + it('streams downloads via blob when running in cloud', async () => { + mockIsCloud = true + const testUrl = 'https://storage.googleapis.com/bucket/file.bin' + const blob = new Blob(['test']) + const blobFn = vi.fn().mockResolvedValue(blob) + fetchMock.mockResolvedValue({ + ok: true, + status: 200, + blob: blobFn + } as unknown as Response) + + downloadFile(testUrl) + + expect(fetchMock).toHaveBeenCalledWith(testUrl) + const fetchPromise = fetchMock.mock.results[0].value as Promise + await fetchPromise + const blobPromise = blobFn.mock.results[0].value as Promise + await blobPromise + await Promise.resolve() + expect(blobFn).toHaveBeenCalled() + expect(createObjectURLSpy).toHaveBeenCalledWith(blob) + expect(revokeObjectURLSpy).toHaveBeenCalledWith('blob:mock-url') + expect(mockLink.click).toHaveBeenCalled() + }) + + it('logs an error when cloud fetch fails', async () => { + mockIsCloud = true + const testUrl = 'https://storage.googleapis.com/bucket/missing.bin' + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + fetchMock.mockResolvedValue({ + ok: false, + status: 404, + blob: vi.fn() + } as unknown as Response) + + downloadFile(testUrl) + + expect(fetchMock).toHaveBeenCalledWith(testUrl) + const fetchPromise = fetchMock.mock.results[0].value as Promise + await fetchPromise + await Promise.resolve() + expect(consoleSpy).toHaveBeenCalled() + expect(createObjectURLSpy).not.toHaveBeenCalled() + consoleSpy.mockRestore() }) }) })