Model file import support for desktop

This commit is contained in:
pythongosssss
2025-03-16 18:50:58 +00:00
parent 6ed870d431
commit c72ba664ee
8 changed files with 296 additions and 25 deletions

View File

@@ -16,8 +16,10 @@ import { computed, onMounted } from 'vue'
import GlobalDialog from '@/components/dialog/GlobalDialog.vue'
import config from '@/config'
import { api } from '@/scripts/api'
import { useWorkspaceStore } from '@/stores/workspaceStore'
import { useDialogService } from './services/dialogService'
import { electronAPI, isElectron } from './utils/envUtil'
const workspaceStore = useWorkspaceStore()
@@ -46,6 +48,20 @@ onMounted(() => {
if (isElectron()) {
document.addEventListener('contextmenu', showContextMenu)
// Handle file drops to import models via electron
api.addEventListener('unhandledFileDrop', async (e) => {
e.preventDefault() // Prevent the "unable to find workflow in file" error
const filePath = await electronAPI()['getFilePath'](e.detail.file)
if (filePath) {
useDialogService().showImportModelDialog({
path: filePath,
file: e.detail.file
})
}
})
}
})
</script>

View File

@@ -0,0 +1,123 @@
<template>
<div class="px-4 py-2 h-full gap-2 flex flex-col">
<h2 class="text-4xl font-normal my-0">
{{ t('importModelDialog.title') }}
</h2>
<span class="text-muted">{{ path }}</span>
<div class="flex flex-col gap-2 mt-4">
<IftaLabel>
<Select
v-model="selectedType"
:options="modelFolders"
editable
filter
labelId="model-type"
:disabled="importing"
/>
<label for="model-type">Type</label>
</IftaLabel>
</div>
<Message severity="error" v-if="importError">{{ importError }}</Message>
</div>
<footer>
<div class="flex justify-between gap-2 p-4">
<SelectButton
v-model="selectedImportMode"
optionLabel="label"
optionValue="value"
:options="importModes"
:disabled="importing"
/>
<div class="flex gap-2">
<Button
type="button"
label="Cancel"
severity="secondary"
@click="dialogStore.closeDialog()"
:disabled="importing"
></Button>
<Button
type="button"
label="Import"
@click="importModel()"
:icon="importIcon"
:loading="importing"
:disabled="!selectedType"
></Button>
</div>
</div>
</footer>
</template>
<script setup lang="ts">
import Button from 'primevue/button'
import IftaLabel from 'primevue/iftalabel'
import Message from 'primevue/message'
import Select from 'primevue/select'
import SelectButton from 'primevue/selectbutton'
import { computed, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { useCommandStore } from '@/stores/commandStore'
import { useDialogStore } from '@/stores/dialogStore'
import { useModelStore } from '@/stores/modelStore'
import { electronAPI } from '@/utils/envUtil'
import { guessModelType } from '@/utils/safetensorsUtil'
// i18n translate function for the dialog strings
const { t } = useI18n()
const dialogStore = useDialogStore()
// Dialog props: the file's path (resolved by the desktop app) and the dropped File.
const { path, file } = defineProps<{
path: string
file: File
}>()
// Options for the move/copy SelectButton in the dialog footer.
const importModes = ref([
{ label: t('importModelDialog.move'), value: 'move' },
{ label: t('importModelDialog.copy'), value: 'copy' }
])
const modelStore = useModelStore()
// Model folder names shown in the Type select; populated by init() below.
const modelFolders = ref<string[]>()
// Target model folder; pre-filled from the safetensors header guess in init().
const selectedType = ref<string>()
// Import mode: 'move' (default) or 'copy'.
const selectedImportMode = ref<string>('move')
// True while the import call is in flight; disables the form controls.
const importing = ref<boolean>(false)
// Error message displayed in the dialog when the import fails.
const importError = ref<string>()
// Icon for the Import button, reflecting the currently selected import mode.
const importIcon = computed(() => {
  if (selectedImportMode.value === 'move') {
    return 'pi pi-file-import'
  }
  return 'pi pi-copy'
})
/**
 * Imports the dropped model file via the desktop (Electron) API using the
 * selected target folder and import mode, then refreshes node definitions
 * so the imported model appears in widgets. Failures are shown in the dialog.
 */
const importModel = async () => {
  importing.value = true
  // Clear any error left over from a previous attempt.
  importError.value = undefined
  try {
    await electronAPI()?.['importModel'](
      file,
      selectedType.value,
      selectedImportMode.value
    )
    await useCommandStore().execute('Comfy.RefreshNodeDefinitions')
    dialogStore.closeDialog()
  } catch (error) {
    console.error(error)
    // Catch variables are `unknown` under strict TS; narrow before reading .message.
    importError.value = error instanceof Error ? error.message : String(error)
  } finally {
    importing.value = false
  }
}
const init = async () => {
if (!modelStore.modelFolders.length) {
await modelStore.loadModelFolders()
}
modelFolders.value = modelStore.modelFolders.map((folder) => folder.directory)
const type = await guessModelType(file)
if (!selectedType.value) {
selectedType.value = type
}
}
init()
</script>

View File

@@ -1405,5 +1405,11 @@
"tooltip": "Execute to selected output nodes (Highlighted with orange border)",
"disabledTooltip": "No output nodes selected"
}
},
"importModelDialog": {
"title": "Import Model",
"type": "Type",
"move": "Move",
"copy": "Copy"
}
}
}

View File

@@ -82,6 +82,7 @@ interface QueuePromptRequestBody {
interface FrontendApiCalls {
graphChanged: ComfyWorkflowJSON
promptQueued: { number: number; batchCount: number }
unhandledFileDrop: { file: File }
graphCleared: never
reconnecting: never
reconnected: never
@@ -313,20 +314,23 @@ export class ComfyApi extends EventTarget {
* Provides type safety for the contravariance issue with EventTarget (last checked TS 5.6).
* @param type The type of event to emit
* @param detail The detail property used for a custom event ({@link CustomEventInit.detail})
* @param init The event config used for a custom event ({@link CustomEventInit})
*/
dispatchCustomEvent<T extends SimpleApiEvents>(type: T): boolean
dispatchCustomEvent<T extends ComplexApiEvents>(
type: T,
detail: ApiEventTypes[T] | null
detail: ApiEventTypes[T] | null,
init?: EventInit
): boolean
dispatchCustomEvent<T extends keyof ApiEventTypes>(
type: T,
detail?: ApiEventTypes[T]
detail?: ApiEventTypes[T],
init?: EventInit
): boolean {
const event =
detail === undefined
? new CustomEvent(type)
: new CustomEvent(type, { detail })
? new CustomEvent(type, { ...init })
: new CustomEvent(type, { detail, ...init })
return super.dispatchEvent(event)
}

View File

@@ -1252,10 +1252,22 @@ export class ComfyApp {
return !executionStore.lastNodeErrors
}
showErrorOnFileLoad(file: File) {
useToastStore().addAlert(
t('toastMessages.fileLoadError', { fileName: file.name })
/**
 * Fallback for a dropped/opened file that contained no workflow or prompt.
 * Dispatches a cancelable 'unhandledFileDrop' event so other features (e.g.
 * the desktop model-import flow) can claim the file; if nothing cancels the
 * event, shows the file-load error toast.
 */
onUnhandledFile(file: File) {
// Fire custom event to allow other parts of the app to handle the file.
// dispatchEvent returns false when a listener called preventDefault().
const unhandled = api.dispatchCustomEvent(
'unhandledFileDrop',
{ file },
{
cancelable: true
}
)
if (unhandled) {
// Nothing handled the event, so show the load-error toast
useToastStore().addAlert(
t('toastMessages.fileLoadError', { fileName: file.name })
)
}
}
/**
@@ -1291,7 +1303,7 @@ export class ComfyApp {
this.graph.serialize() as unknown as ComfyWorkflowJSON
)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'image/webp') {
const pngInfo = await getWebpMetadata(file)
@@ -1304,7 +1316,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt), fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'audio/mpeg') {
const { workflow, prompt } = await getMp3Metadata(file)
@@ -1313,7 +1325,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'audio/ogg') {
const { workflow, prompt } = await getOggMetadata(file)
@@ -1322,7 +1334,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'audio/flac' || file.type === 'audio/x-flac') {
const pngInfo = await getFlacMetadata(file)
@@ -1334,7 +1346,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt), fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'video/webm') {
const webmInfo = await getFromWebmFile(file)
@@ -1343,7 +1355,7 @@ export class ComfyApp {
} else if (webmInfo.prompt) {
this.loadApiJson(webmInfo.prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (
file.type === 'video/mp4' ||
@@ -1366,7 +1378,7 @@ export class ComfyApp {
} else if (svgInfo.prompt) {
this.loadApiJson(svgInfo.prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (
file.type === 'model/gltf-binary' ||
@@ -1378,7 +1390,7 @@ export class ComfyApp {
} else if (gltfInfo.prompt) {
this.loadApiJson(gltfInfo.prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (
file.type === 'application/json' ||
@@ -1409,7 +1421,7 @@ export class ComfyApp {
const info = await getLatentMetadata(file)
// TODO define schema to LatentMetadata
// @ts-expect-error
if (info.workflow) {
if (info?.workflow) {
await this.loadGraphData(
// @ts-expect-error
JSON.parse(info.workflow),
@@ -1418,14 +1430,14 @@ export class ComfyApp {
fileName
)
// @ts-expect-error
} else if (info.prompt) {
} else if (info?.prompt) {
// @ts-expect-error
this.loadApiJson(JSON.parse(info.prompt))
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
}

View File

@@ -143,12 +143,17 @@ export function getLatentMetadata(file) {
const dataView = new DataView(safetensorsData.buffer)
let header_size = dataView.getUint32(0, true)
let offset = 8
let header = JSON.parse(
new TextDecoder().decode(
safetensorsData.slice(offset, offset + header_size)
try {
let header = JSON.parse(
new TextDecoder().decode(
safetensorsData.slice(offset, offset + header_size)
)
)
)
r(header.__metadata__)
r(header.__metadata__)
} catch (e) {
// Invalid header
r(undefined)
}
}
var slice = file.slice(0, 1024 * 1024 * 4)

View File

@@ -2,6 +2,7 @@ import ApiNodesNewsContent from '@/components/dialog/content/ApiNodesNewsContent
import ApiNodesSignInContent from '@/components/dialog/content/ApiNodesSignInContent.vue'
import ConfirmationDialogContent from '@/components/dialog/content/ConfirmationDialogContent.vue'
import ErrorDialogContent from '@/components/dialog/content/ErrorDialogContent.vue'
import ImportModelDialogContent from '@/components/dialog/content/ImportModelDialogContent.vue'
import IssueReportDialogContent from '@/components/dialog/content/IssueReportDialogContent.vue'
import LoadWorkflowWarning from '@/components/dialog/content/LoadWorkflowWarning.vue'
import ManagerProgressDialogContent from '@/components/dialog/content/ManagerProgressDialogContent.vue'
@@ -406,6 +407,16 @@ export const useDialogService = () => {
})
}
/**
 * Opens the model-import dialog (desktop file-drop flow).
 * The fixed key 'global-import-model' identifies this dialog in the store.
 */
function showImportModelDialog(
props: InstanceType<typeof ImportModelDialogContent>['$props']
) {
dialogStore.showDialog({
key: 'global-import-model',
component: ImportModelDialogContent,
props
})
}
return {
showLoadWorkflowWarning,
showMissingModelsWarning,
@@ -422,6 +433,7 @@ export const useDialogService = () => {
showTopUpCreditsDialog,
showUpdatePasswordDialog,
showApiNodesNewsDialog,
showImportModelDialog,
prompt,
confirm
}

View File

@@ -0,0 +1,93 @@
/**
 * Subset of the SAI model-spec metadata keys that may appear in the
 * `__metadata__` entry of a safetensors header (only the keys this
 * module reads are declared here).
 */
export interface ModelSpec {
'modelspec.sai_model_spec': string
'modelspec.architecture': string
'modelspec.title': string
'modelspec.description': string
}
// Maps a known `modelspec.architecture` value to the model folder the file
// should be imported into.
const architectureToType: Record<string, string> = {
'stable-diffusion-v1': 'checkpoints',
'stable-diffusion-xl-v1-base': 'checkpoints',
'Flux.1-schnell': 'checkpoints',
'Flux.1-dev': 'checkpoints',
'Flux.1-AE': 'vae'
}
// Parsed safetensors JSON header: tensor entries keyed by tensor name, plus
// an optional `__metadata__` object carrying free-form string metadata.
interface SafetensorsHeader<TMetadata = Record<string, string> | ModelSpec> {
[k: string]: unknown
__metadata__?: TMetadata
}
/**
 * Best-effort guess of the model folder ('checkpoints', 'loras', 'vae', …)
 * for a dropped file, based on its safetensors header.
 * @param file The dropped file (read lazily; only the header is parsed).
 * @returns The suggested folder name, or null if it cannot be determined.
 */
export async function guessModelType(file: File): Promise<string | null> {
  const header = await getHeader(file)
  if (!header) return null
  // Must be initialized: when the header has no modelspec metadata, `??=`
  // would otherwise read the variable before assignment (TS2454 under strict).
  let suggestedType: string | null = null
  if (isModelSpec(header)) {
    suggestedType = guessFromModelSpec(header)
  }
  // Fall back to tensor-name heuristics when modelspec gave no answer.
  suggestedType ??= guessFromHeaderKeys(header)
  return suggestedType
}
/**
 * Reads and parses the JSON header of a safetensors file.
 * Layout: 8-byte little-endian unsigned header length, then UTF-8 JSON.
 * @returns The parsed header, or null if the file is not valid safetensors.
 */
async function getHeader(file: File): Promise<SafetensorsHeader | null> {
  try {
    // 8 bytes: an unsigned little-endian 64-bit integer, containing the size of the header
    // Slice the first 8 bytes so we don't read the whole file
    const headerSizeBlob = file.slice(0, 8)
    const headerSizeView = new DataView(await headerSizeBlob.arrayBuffer())
    const headerSize = headerSizeView.getBigUint64(0, true)
    // getBigUint64 is unsigned, so no negative check is needed — only an
    // upper bound against the file size and safe-integer range.
    if (headerSize > file.size || headerSize > Number.MAX_SAFE_INTEGER) {
      // Invalid header, probably not a safetensors file
      console.log(`Invalid header size ${headerSize} for file '${file.name}'`)
      return null
    }
    // N bytes: a JSON UTF-8 string representing the header.
    const header = file.slice(8, Number(headerSize) + 8)
    const content = await header.text()
    return JSON.parse(content)
  } catch (error) {
    // Error reading the file (e.g. shorter than 8 bytes, truncated JSON) —
    // probably not a safetensors file
    console.error(`Error reading safetensors header '${file.name}'`, error)
    return null
  }
}
/**
 * Maps the model-spec architecture string to a model folder. Known
 * architectures are looked up directly; `<base>/lora` architectures map
 * to 'loras'.
 * @returns The folder name, or null when the architecture is missing/unknown.
 */
function guessFromModelSpec(
  header: SafetensorsHeader<ModelSpec>
): string | null {
  const architecture = header.__metadata__?.['modelspec.architecture']
  if (!architecture) return null
  // `architecture` is non-empty here, so no optional chaining is needed.
  // Normalize "not found" to null (the lookup alone would yield undefined).
  return (
    architectureToType[architecture] ??
    (architecture.endsWith('/lora') ? 'loras' : null)
  )
}
/**
 * Fallback heuristic: infer the model folder from well-known tensor-name
 * prefixes in the safetensors header.
 * @returns 'loras', 'checkpoints', or null when no known prefix is present.
 */
function guessFromHeaderKeys(header: SafetensorsHeader): string | null {
  const keys = Object.keys(header)
  // `some` (not `find`): only a boolean is needed, not the matching key.
  if (keys.some((k) => k.startsWith('lora_unet_'))) {
    return 'loras'
  }
  if (keys.some((k) => k.startsWith('model.diffusion_model.'))) {
    return 'checkpoints'
  }
  return null
}
// Type guard: treat the header as model-spec metadata when the
// 'modelspec.sai_model_spec' marker key is present and non-empty.
function isModelSpec(
  header: SafetensorsHeader
): header is SafetensorsHeader<ModelSpec> {
  const metadata = header.__metadata__
  return Boolean(metadata && metadata['modelspec.sai_model_spec'])
}