diff --git a/src/App.vue b/src/App.vue index 85b36240c..2b2cf1d6b 100644 --- a/src/App.vue +++ b/src/App.vue @@ -16,8 +16,10 @@ import { computed, onMounted } from 'vue' import GlobalDialog from '@/components/dialog/GlobalDialog.vue' import config from '@/config' +import { api } from '@/scripts/api' import { useWorkspaceStore } from '@/stores/workspaceStore' +import { useDialogService } from './services/dialogService' import { electronAPI, isElectron } from './utils/envUtil' const workspaceStore = useWorkspaceStore() @@ -46,6 +48,20 @@ onMounted(() => { if (isElectron()) { document.addEventListener('contextmenu', showContextMenu) + + // Handle file drops to import models via electron + api.addEventListener('unhandledFileDrop', async (e) => { + e.preventDefault() // Prevent unable to find workflow in file error + + const filePath = await electronAPI()['getFilePath'](e.detail.file) + + if (filePath) { + useDialogService().showImportModelDialog({ + path: filePath, + file: e.detail.file + }) + } + }) } }) diff --git a/src/components/dialog/content/ImportModelDialogContent.vue b/src/components/dialog/content/ImportModelDialogContent.vue new file mode 100644 index 000000000..fe6f9f4c0 --- /dev/null +++ b/src/components/dialog/content/ImportModelDialogContent.vue @@ -0,0 +1,123 @@ + + + diff --git a/src/locales/en/main.json b/src/locales/en/main.json index 79dbbb9d2..f1a821275 100644 --- a/src/locales/en/main.json +++ b/src/locales/en/main.json @@ -1405,5 +1405,11 @@ "tooltip": "Execute to selected output nodes (Highlighted with orange border)", "disabledTooltip": "No output nodes selected" } + }, + "importModelDialog": { + "title": "Import Model", + "type": "Type", + "move": "Move", + "copy": "Copy" } -} \ No newline at end of file +} diff --git a/src/scripts/api.ts b/src/scripts/api.ts index ba6cb569e..4a8a0d856 100644 --- a/src/scripts/api.ts +++ b/src/scripts/api.ts @@ -82,6 +82,7 @@ interface QueuePromptRequestBody { interface FrontendApiCalls { graphChanged: ComfyWorkflowJSON promptQueued: { number: number; batchCount: number } + unhandledFileDrop: { file: File } graphCleared: never reconnecting: never reconnected: never @@ -313,20 +314,23 @@ export class ComfyApi extends EventTarget { * Provides type safety for the contravariance issue with EventTarget (last checked TS 5.6). * @param type The type of event to emit * @param detail The detail property used for a custom event ({@link CustomEventInit.detail}) + * @param init The event config used for a custom event ({@link CustomEventInit}) */ dispatchCustomEvent(type: T): boolean dispatchCustomEvent( type: T, - detail: ApiEventTypes[T] | null + detail: ApiEventTypes[T] | null, + init?: EventInit ): boolean dispatchCustomEvent( type: T, - detail?: ApiEventTypes[T] + detail?: ApiEventTypes[T], + init?: EventInit ): boolean { const event = detail === undefined - ? new CustomEvent(type) - : new CustomEvent(type, { detail }) + ? new CustomEvent(type, { ...init }) + : new CustomEvent(type, { detail, ...init }) return super.dispatchEvent(event) } diff --git a/src/scripts/app.ts b/src/scripts/app.ts index 7f4dceb4e..e52a1de20 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -1252,10 +1252,22 @@ export class ComfyApp { return !executionStore.lastNodeErrors } - showErrorOnFileLoad(file: File) { - useToastStore().addAlert( - t('toastMessages.fileLoadError', { fileName: file.name }) + onUnhandledFile(file: File) { + // Fire custom event to allow other parts of the app to handle the file + const unhandled = api.dispatchCustomEvent( + 'unhandledFileDrop', + { file }, + { + cancelable: true + } ) + + if (unhandled) { + // Nothing handled the event, so show the error dialog + useToastStore().addAlert( + t('toastMessages.fileLoadError', { fileName: file.name }) + ) + } } /** @@ -1291,7 +1303,7 @@ export class ComfyApp { this.graph.serialize() as unknown as ComfyWorkflowJSON ) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if (file.type === 'image/webp') { const pngInfo = await getWebpMetadata(file) @@ -1304,7 +1316,7 @@ export class ComfyApp { } else if (prompt) { this.loadApiJson(JSON.parse(prompt), fileName) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if (file.type === 'audio/mpeg') { const { workflow, prompt } = await getMp3Metadata(file) @@ -1313,7 +1325,7 @@ export class ComfyApp { } else if (prompt) { this.loadApiJson(prompt, fileName) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if (file.type === 'audio/ogg') { const { workflow, prompt } = await getOggMetadata(file) @@ -1322,7 +1334,7 @@ export class ComfyApp { } else if (prompt) { this.loadApiJson(prompt, fileName) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if (file.type === 'audio/flac' || file.type === 'audio/x-flac') { const pngInfo = await getFlacMetadata(file) @@ -1334,7 +1346,7 @@ export class ComfyApp { } else if (prompt) { this.loadApiJson(JSON.parse(prompt), fileName) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if (file.type === 'video/webm') { const webmInfo = await getFromWebmFile(file) @@ -1343,7 +1355,7 @@ export class ComfyApp { } else if (webmInfo.prompt) { this.loadApiJson(webmInfo.prompt, fileName) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if ( file.type === 'video/mp4' || @@ -1366,7 +1378,7 @@ export class ComfyApp { } else if (svgInfo.prompt) { this.loadApiJson(svgInfo.prompt, fileName) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if ( file.type === 'model/gltf-binary' || @@ -1378,7 +1390,7 @@ export class ComfyApp { } else if (gltfInfo.prompt) { this.loadApiJson(gltfInfo.prompt, fileName) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else if ( file.type === 'application/json' || @@ -1409,7 +1421,7 @@ export class ComfyApp { const info = await getLatentMetadata(file) // TODO define schema to LatentMetadata // @ts-expect-error - if (info.workflow) { + if (info?.workflow) { await this.loadGraphData( // @ts-expect-error JSON.parse(info.workflow), @@ -1418,14 +1430,14 @@ export class ComfyApp { fileName ) // @ts-expect-error - } else if (info.prompt) { + } else if (info?.prompt) { // @ts-expect-error this.loadApiJson(JSON.parse(info.prompt)) } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } else { - this.showErrorOnFileLoad(file) + this.onUnhandledFile(file) } } diff --git a/src/scripts/pnginfo.ts b/src/scripts/pnginfo.ts index 7c7f6f597..6b295a366 100644 --- a/src/scripts/pnginfo.ts +++ b/src/scripts/pnginfo.ts @@ -143,12 +143,17 @@ export function getLatentMetadata(file) { const dataView = new DataView(safetensorsData.buffer) let header_size = dataView.getUint32(0, true) let offset = 8 - let header = JSON.parse( - new TextDecoder().decode( - safetensorsData.slice(offset, offset + header_size) + try { + let header = JSON.parse( + new TextDecoder().decode( + safetensorsData.slice(offset, offset + header_size) + ) ) - ) - r(header.__metadata__) + r(header.__metadata__) + } catch (e) { + // Invalid header + r(undefined) + } } var slice = file.slice(0, 1024 * 1024 * 4) diff --git a/src/services/dialogService.ts b/src/services/dialogService.ts index 7c46f5c2e..9ba97fbb1 100644 --- a/src/services/dialogService.ts +++ b/src/services/dialogService.ts @@ -2,6 +2,7 @@ import ApiNodesNewsContent from '@/components/dialog/content/ApiNodesNewsContent import ApiNodesSignInContent from '@/components/dialog/content/ApiNodesSignInContent.vue' import ConfirmationDialogContent from '@/components/dialog/content/ConfirmationDialogContent.vue' import ErrorDialogContent from '@/components/dialog/content/ErrorDialogContent.vue' +import ImportModelDialogContent from '@/components/dialog/content/ImportModelDialogContent.vue' import IssueReportDialogContent from '@/components/dialog/content/IssueReportDialogContent.vue' import LoadWorkflowWarning from '@/components/dialog/content/LoadWorkflowWarning.vue' import ManagerProgressDialogContent from '@/components/dialog/content/ManagerProgressDialogContent.vue' @@ -406,6 +407,16 @@ export const useDialogService = () => { }) } + function showImportModelDialog( + props: InstanceType['$props'] + ) { + dialogStore.showDialog({ + key: 'global-import-model', + component: ImportModelDialogContent, + props + }) + } + return { showLoadWorkflowWarning, showMissingModelsWarning, @@ -422,6 +433,7 @@ export const useDialogService = () => { showTopUpCreditsDialog, showUpdatePasswordDialog, showApiNodesNewsDialog, + showImportModelDialog, prompt, confirm } diff --git a/src/utils/safetensorsUtil.ts b/src/utils/safetensorsUtil.ts new file mode 100644 index 000000000..6535fd89d --- /dev/null +++ b/src/utils/safetensorsUtil.ts @@ -0,0 +1,93 @@ +export interface ModelSpec { + 'modelspec.sai_model_spec': string + 'modelspec.architecture': string + 'modelspec.title': string + 'modelspec.description': string +} + +const architectureToType: Record = { + 'stable-diffusion-v1': 'checkpoints', + 'stable-diffusion-xl-v1-base': 'checkpoints', + 'Flux.1-schnell': 'checkpoints', + 'Flux.1-dev': 'checkpoints', + 'Flux.1-AE': 'vae' +} + +interface SafetensorsHeader | ModelSpec> { + [k: string]: unknown + __metadata__?: TMetadata +} + +export async function guessModelType(file: File): Promise { + const header = await getHeader(file) + if (!header) return null + + let suggestedType: string | null + if (isModelSpec(header)) { + suggestedType = guessFromModelSpec(header) + } + + suggestedType ??= guessFromHeaderKeys(header) + + return suggestedType +} + +async function getHeader(file: File): Promise { + try { + // 8 bytes: an unsigned little-endian 64-bit integer, containing the size of the header + // Slice the first 8 bytes so we don't read the whole file + const headerSizeBlob = file.slice(0, 8) + const headerSizeView = new DataView(await headerSizeBlob.arrayBuffer()) + const headerSize = headerSizeView.getBigUint64(0, true) + + if ( + headerSize < 0 || + headerSize > file.size || + headerSize > Number.MAX_SAFE_INTEGER + ) { + // Invalid header, probably not a safetensors file + console.log(`Invalid header size ${headerSize} for file '${file.name}'`) + return null + } + + // N bytes: a JSON UTF-8 string representing the header. + const header = file.slice(8, Number(headerSize) + 8) + const content = await header.text() + return JSON.parse(content) + } catch (error) { + // Error reading the file, probably not a safetensors file + console.error(`Error reading safetensors header '${file.name}'`, error) + return null + } +} + +function guessFromModelSpec(header: SafetensorsHeader) { + const architecture = header.__metadata__?.['modelspec.architecture'] + if (!architecture) return null + let suggestedType = architectureToType[architecture] + if (!suggestedType) { + if (architecture?.endsWith('/lora')) { + suggestedType = 'loras' + } + } + + return suggestedType +} + +function guessFromHeaderKeys(header: SafetensorsHeader) { + let suggestedType: string | null = null + const keys = Object.keys(header) + if (keys.find((k) => k.startsWith('lora_unet_'))) { + suggestedType = 'loras' + } else if (keys.find((k) => k.startsWith('model.diffusion_model.'))) { + suggestedType = 'checkpoints' + } + + return suggestedType +} + +function isModelSpec( + header: SafetensorsHeader +): header is SafetensorsHeader { + return !!header.__metadata__?.['modelspec.sai_model_spec'] +}