From c72ba664ee3df3698ed7118cce53f2fb47f21d9a Mon Sep 17 00:00:00 2001
From: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
Date: Sun, 16 Mar 2025 18:50:58 +0000
Subject: [PATCH] Model file import support for desktop
---
src/App.vue | 16 +++
.../content/ImportModelDialogContent.vue | 123 ++++++++++++++++++
src/locales/en/main.json | 8 +-
src/scripts/api.ts | 12 +-
src/scripts/app.ts | 42 +++---
src/scripts/pnginfo.ts | 15 ++-
src/services/dialogService.ts | 12 ++
src/utils/safetensorsUtil.ts | 93 +++++++++++++
8 files changed, 296 insertions(+), 25 deletions(-)
create mode 100644 src/components/dialog/content/ImportModelDialogContent.vue
create mode 100644 src/utils/safetensorsUtil.ts
diff --git a/src/App.vue b/src/App.vue
index 85b36240c..2b2cf1d6b 100644
--- a/src/App.vue
+++ b/src/App.vue
@@ -16,8 +16,10 @@ import { computed, onMounted } from 'vue'
import GlobalDialog from '@/components/dialog/GlobalDialog.vue'
import config from '@/config'
+import { api } from '@/scripts/api'
import { useWorkspaceStore } from '@/stores/workspaceStore'
+import { useDialogService } from './services/dialogService'
import { electronAPI, isElectron } from './utils/envUtil'
const workspaceStore = useWorkspaceStore()
@@ -46,6 +48,20 @@ onMounted(() => {
if (isElectron()) {
document.addEventListener('contextmenu', showContextMenu)
+
+ // Handle file drops to import models via electron
+ api.addEventListener('unhandledFileDrop', async (e) => {
+ e.preventDefault() // Prevent unable to find workflow in file error
+
+ const filePath = await electronAPI()['getFilePath'](e.detail.file)
+
+ if (filePath) {
+ useDialogService().showImportModelDialog({
+ path: filePath,
+ file: e.detail.file
+ })
+ }
+ })
}
})
diff --git a/src/components/dialog/content/ImportModelDialogContent.vue b/src/components/dialog/content/ImportModelDialogContent.vue
new file mode 100644
index 000000000..fe6f9f4c0
--- /dev/null
+++ b/src/components/dialog/content/ImportModelDialogContent.vue
@@ -0,0 +1,123 @@
+
+
+
+ {{ t('importModelDialog.title') }}
+
+
{{ path }}
+
+
+
+
+
+
+
{{ importError }}
+
+
+
+
+
diff --git a/src/locales/en/main.json b/src/locales/en/main.json
index 79dbbb9d2..f1a821275 100644
--- a/src/locales/en/main.json
+++ b/src/locales/en/main.json
@@ -1405,5 +1405,11 @@
"tooltip": "Execute to selected output nodes (Highlighted with orange border)",
"disabledTooltip": "No output nodes selected"
}
+ },
+ "importModelDialog": {
+ "title": "Import Model",
+ "type": "Type",
+ "move": "Move",
+ "copy": "Copy"
}
-}
\ No newline at end of file
+}
diff --git a/src/scripts/api.ts b/src/scripts/api.ts
index ba6cb569e..4a8a0d856 100644
--- a/src/scripts/api.ts
+++ b/src/scripts/api.ts
@@ -82,6 +82,7 @@ interface QueuePromptRequestBody {
interface FrontendApiCalls {
graphChanged: ComfyWorkflowJSON
promptQueued: { number: number; batchCount: number }
+ unhandledFileDrop: { file: File }
graphCleared: never
reconnecting: never
reconnected: never
@@ -313,20 +314,23 @@ export class ComfyApi extends EventTarget {
* Provides type safety for the contravariance issue with EventTarget (last checked TS 5.6).
* @param type The type of event to emit
* @param detail The detail property used for a custom event ({@link CustomEventInit.detail})
+ * @param init The event config used for a custom event ({@link CustomEventInit})
*/
dispatchCustomEvent(type: T): boolean
dispatchCustomEvent(
type: T,
- detail: ApiEventTypes[T] | null
+ detail: ApiEventTypes[T] | null,
+ init?: EventInit
): boolean
dispatchCustomEvent(
type: T,
- detail?: ApiEventTypes[T]
+ detail?: ApiEventTypes[T],
+ init?: EventInit
): boolean {
const event =
detail === undefined
- ? new CustomEvent(type)
- : new CustomEvent(type, { detail })
+ ? new CustomEvent(type, { ...init })
+ : new CustomEvent(type, { detail, ...init })
return super.dispatchEvent(event)
}
diff --git a/src/scripts/app.ts b/src/scripts/app.ts
index 7f4dceb4e..e52a1de20 100644
--- a/src/scripts/app.ts
+++ b/src/scripts/app.ts
@@ -1252,10 +1252,22 @@ export class ComfyApp {
return !executionStore.lastNodeErrors
}
- showErrorOnFileLoad(file: File) {
- useToastStore().addAlert(
- t('toastMessages.fileLoadError', { fileName: file.name })
+ onUnhandledFile(file: File) {
+ // Fire custom event to allow other parts of the app to handle the file
+ const unhandled = api.dispatchCustomEvent(
+ 'unhandledFileDrop',
+ { file },
+ {
+ cancelable: true
+ }
)
+
+ if (unhandled) {
+ // Nothing handled the event, so show the error dialog
+ useToastStore().addAlert(
+ t('toastMessages.fileLoadError', { fileName: file.name })
+ )
+ }
}
/**
@@ -1291,7 +1303,7 @@ export class ComfyApp {
this.graph.serialize() as unknown as ComfyWorkflowJSON
)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (file.type === 'image/webp') {
const pngInfo = await getWebpMetadata(file)
@@ -1304,7 +1316,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt), fileName)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (file.type === 'audio/mpeg') {
const { workflow, prompt } = await getMp3Metadata(file)
@@ -1313,7 +1325,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(prompt, fileName)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (file.type === 'audio/ogg') {
const { workflow, prompt } = await getOggMetadata(file)
@@ -1322,7 +1334,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(prompt, fileName)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (file.type === 'audio/flac' || file.type === 'audio/x-flac') {
const pngInfo = await getFlacMetadata(file)
@@ -1334,7 +1346,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt), fileName)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (file.type === 'video/webm') {
const webmInfo = await getFromWebmFile(file)
@@ -1343,7 +1355,7 @@ export class ComfyApp {
} else if (webmInfo.prompt) {
this.loadApiJson(webmInfo.prompt, fileName)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (
file.type === 'video/mp4' ||
@@ -1366,7 +1378,7 @@ export class ComfyApp {
} else if (svgInfo.prompt) {
this.loadApiJson(svgInfo.prompt, fileName)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (
file.type === 'model/gltf-binary' ||
@@ -1378,7 +1390,7 @@ export class ComfyApp {
} else if (gltfInfo.prompt) {
this.loadApiJson(gltfInfo.prompt, fileName)
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else if (
file.type === 'application/json' ||
@@ -1409,7 +1421,7 @@ export class ComfyApp {
const info = await getLatentMetadata(file)
// TODO define schema to LatentMetadata
// @ts-expect-error
- if (info.workflow) {
+ if (info?.workflow) {
await this.loadGraphData(
// @ts-expect-error
JSON.parse(info.workflow),
@@ -1418,14 +1430,14 @@ export class ComfyApp {
fileName
)
// @ts-expect-error
- } else if (info.prompt) {
+ } else if (info?.prompt) {
// @ts-expect-error
this.loadApiJson(JSON.parse(info.prompt))
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
} else {
- this.showErrorOnFileLoad(file)
+ this.onUnhandledFile(file)
}
}
diff --git a/src/scripts/pnginfo.ts b/src/scripts/pnginfo.ts
index 7c7f6f597..6b295a366 100644
--- a/src/scripts/pnginfo.ts
+++ b/src/scripts/pnginfo.ts
@@ -143,12 +143,17 @@ export function getLatentMetadata(file) {
const dataView = new DataView(safetensorsData.buffer)
let header_size = dataView.getUint32(0, true)
let offset = 8
- let header = JSON.parse(
- new TextDecoder().decode(
- safetensorsData.slice(offset, offset + header_size)
+ try {
+ let header = JSON.parse(
+ new TextDecoder().decode(
+ safetensorsData.slice(offset, offset + header_size)
+ )
)
- )
- r(header.__metadata__)
+ r(header.__metadata__)
+ } catch (e) {
+ // Invalid header
+ r(undefined)
+ }
}
var slice = file.slice(0, 1024 * 1024 * 4)
diff --git a/src/services/dialogService.ts b/src/services/dialogService.ts
index 7c46f5c2e..9ba97fbb1 100644
--- a/src/services/dialogService.ts
+++ b/src/services/dialogService.ts
@@ -2,6 +2,7 @@ import ApiNodesNewsContent from '@/components/dialog/content/ApiNodesNewsContent
import ApiNodesSignInContent from '@/components/dialog/content/ApiNodesSignInContent.vue'
import ConfirmationDialogContent from '@/components/dialog/content/ConfirmationDialogContent.vue'
import ErrorDialogContent from '@/components/dialog/content/ErrorDialogContent.vue'
+import ImportModelDialogContent from '@/components/dialog/content/ImportModelDialogContent.vue'
import IssueReportDialogContent from '@/components/dialog/content/IssueReportDialogContent.vue'
import LoadWorkflowWarning from '@/components/dialog/content/LoadWorkflowWarning.vue'
import ManagerProgressDialogContent from '@/components/dialog/content/ManagerProgressDialogContent.vue'
@@ -406,6 +407,16 @@ export const useDialogService = () => {
})
}
+ function showImportModelDialog(
+ props: InstanceType['$props']
+ ) {
+ dialogStore.showDialog({
+ key: 'global-import-model',
+ component: ImportModelDialogContent,
+ props
+ })
+ }
+
return {
showLoadWorkflowWarning,
showMissingModelsWarning,
@@ -422,6 +433,7 @@ export const useDialogService = () => {
showTopUpCreditsDialog,
showUpdatePasswordDialog,
showApiNodesNewsDialog,
+ showImportModelDialog,
prompt,
confirm
}
diff --git a/src/utils/safetensorsUtil.ts b/src/utils/safetensorsUtil.ts
new file mode 100644
index 000000000..6535fd89d
--- /dev/null
+++ b/src/utils/safetensorsUtil.ts
@@ -0,0 +1,93 @@
+export interface ModelSpec {
+ 'modelspec.sai_model_spec': string
+ 'modelspec.architecture': string
+ 'modelspec.title': string
+ 'modelspec.description': string
+}
+
+const architectureToType: Record = {
+ 'stable-diffusion-v1': 'checkpoints',
+ 'stable-diffusion-xl-v1-base': 'checkpoints',
+ 'Flux.1-schnell': 'checkpoints',
+ 'Flux.1-dev': 'checkpoints',
+ 'Flux.1-AE': 'vae'
+}
+
+interface SafetensorsHeader | ModelSpec> {
+ [k: string]: unknown
+ __metadata__?: TMetadata
+}
+
+export async function guessModelType(file: File): Promise {
+ const header = await getHeader(file)
+ if (!header) return null
+
+ let suggestedType: string | null
+ if (isModelSpec(header)) {
+ suggestedType = guessFromModelSpec(header)
+ }
+
+ suggestedType ??= guessFromHeaderKeys(header)
+
+ return suggestedType
+}
+
+async function getHeader(file: File): Promise {
+ try {
+ // 8 bytes: an unsigned little-endian 64-bit integer, containing the size of the header
+ // Slice the first 8 bytes so we don't read the whole file
+ const headerSizeBlob = file.slice(0, 8)
+ const headerSizeView = new DataView(await headerSizeBlob.arrayBuffer())
+ const headerSize = headerSizeView.getBigUint64(0, true)
+
+ if (
+ headerSize < 0 ||
+ headerSize > file.size ||
+ headerSize > Number.MAX_SAFE_INTEGER
+ ) {
+ // Invalid header, probably not a safetensors file
+ console.log(`Invalid header size ${headerSize} for file '${file.name}'`)
+ return null
+ }
+
+ // N bytes: a JSON UTF-8 string representing the header.
+ const header = file.slice(8, Number(headerSize) + 8)
+ const content = await header.text()
+ return JSON.parse(content)
+ } catch (error) {
+ // Error reading the file, probably not a safetensors file
+ console.error(`Error reading safetensors header '${file.name}'`, error)
+ return null
+ }
+}
+
+function guessFromModelSpec(header: SafetensorsHeader) {
+ const architecture = header.__metadata__?.['modelspec.architecture']
+ if (!architecture) return null
+ let suggestedType = architectureToType[architecture]
+ if (!suggestedType) {
+ if (architecture?.endsWith('/lora')) {
+ suggestedType = 'loras'
+ }
+ }
+
+ return suggestedType
+}
+
+function guessFromHeaderKeys(header: SafetensorsHeader) {
+ let suggestedType: string | null = null
+ const keys = Object.keys(header)
+ if (keys.find((k) => k.startsWith('lora_unet_'))) {
+ suggestedType = 'loras'
+ } else if (keys.find((k) => k.startsWith('model.diffusion_model.'))) {
+ suggestedType = 'checkpoints'
+ }
+
+ return suggestedType
+}
+
+function isModelSpec(
+ header: SafetensorsHeader
+): header is SafetensorsHeader {
+ return !!header.__metadata__?.['modelspec.sai_model_spec']
+}