mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-08 00:50:05 +00:00
Model file import support for desktop
This commit is contained in:
16
src/App.vue
16
src/App.vue
@@ -16,8 +16,10 @@ import { computed, onMounted } from 'vue'
|
||||
|
||||
import GlobalDialog from '@/components/dialog/GlobalDialog.vue'
|
||||
import config from '@/config'
|
||||
import { api } from '@/scripts/api'
|
||||
import { useWorkspaceStore } from '@/stores/workspaceStore'
|
||||
|
||||
import { useDialogService } from './services/dialogService'
|
||||
import { electronAPI, isElectron } from './utils/envUtil'
|
||||
|
||||
const workspaceStore = useWorkspaceStore()
|
||||
@@ -46,6 +48,20 @@ onMounted(() => {
|
||||
|
||||
if (isElectron()) {
|
||||
document.addEventListener('contextmenu', showContextMenu)
|
||||
|
||||
// Handle file drops to import models via electron
|
||||
api.addEventListener('unhandledFileDrop', async (e) => {
|
||||
e.preventDefault() // Prevent unable to find workflow in file error
|
||||
|
||||
const filePath = await electronAPI()['getFilePath'](e.detail.file)
|
||||
|
||||
if (filePath) {
|
||||
useDialogService().showImportModelDialog({
|
||||
path: filePath,
|
||||
file: e.detail.file
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
123
src/components/dialog/content/ImportModelDialogContent.vue
Normal file
123
src/components/dialog/content/ImportModelDialogContent.vue
Normal file
@@ -0,0 +1,123 @@
|
||||
<template>
|
||||
<div class="px-4 py-2 h-full gap-2 flex flex-col">
|
||||
<h2 class="text-4xl font-normal my-0">
|
||||
{{ t('importModelDialog.title') }}
|
||||
</h2>
|
||||
<span class="text-muted">{{ path }}</span>
|
||||
<div class="flex flex-col gap-2 mt-4">
|
||||
<IftaLabel>
|
||||
<Select
|
||||
v-model="selectedType"
|
||||
:options="modelFolders"
|
||||
editable
|
||||
filter
|
||||
labelId="model-type"
|
||||
:disabled="importing"
|
||||
/>
|
||||
<label for="model-type">Type</label>
|
||||
</IftaLabel>
|
||||
</div>
|
||||
<Message severity="error" v-if="importError">{{ importError }}</Message>
|
||||
</div>
|
||||
<footer>
|
||||
<div class="flex justify-between gap-2 p-4">
|
||||
<SelectButton
|
||||
v-model="selectedImportMode"
|
||||
optionLabel="label"
|
||||
optionValue="value"
|
||||
:options="importModes"
|
||||
:disabled="importing"
|
||||
/>
|
||||
<div class="flex gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
label="Cancel"
|
||||
severity="secondary"
|
||||
@click="dialogStore.closeDialog()"
|
||||
:disabled="importing"
|
||||
></Button>
|
||||
<Button
|
||||
type="button"
|
||||
label="Import"
|
||||
@click="importModel()"
|
||||
:icon="importIcon"
|
||||
:loading="importing"
|
||||
:disabled="!selectedType"
|
||||
></Button>
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import Button from 'primevue/button'
|
||||
import IftaLabel from 'primevue/iftalabel'
|
||||
import Message from 'primevue/message'
|
||||
import Select from 'primevue/select'
|
||||
import SelectButton from 'primevue/selectbutton'
|
||||
import { computed, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useModelStore } from '@/stores/modelStore'
|
||||
import { electronAPI } from '@/utils/envUtil'
|
||||
import { guessModelType } from '@/utils/safetensorsUtil'
|
||||
|
||||
const { t } = useI18n()
|
||||
const dialogStore = useDialogStore()
|
||||
const { path, file } = defineProps<{
|
||||
path: string
|
||||
file: File
|
||||
}>()
|
||||
|
||||
const importModes = ref([
|
||||
{ label: t('importModelDialog.move'), value: 'move' },
|
||||
{ label: t('importModelDialog.copy'), value: 'copy' }
|
||||
])
|
||||
const modelStore = useModelStore()
|
||||
const modelFolders = ref<string[]>()
|
||||
const selectedType = ref<string>()
|
||||
const selectedImportMode = ref<string>('move')
|
||||
const importing = ref<boolean>(false)
|
||||
const importError = ref<string>()
|
||||
|
||||
const importIcon = computed(() => {
|
||||
return selectedImportMode.value === 'move'
|
||||
? 'pi pi-file-import'
|
||||
: 'pi pi-copy'
|
||||
})
|
||||
|
||||
const importModel = async () => {
|
||||
importing.value = true
|
||||
try {
|
||||
await electronAPI()?.['importModel'](
|
||||
file,
|
||||
selectedType.value,
|
||||
selectedImportMode.value
|
||||
)
|
||||
await useCommandStore().execute('Comfy.RefreshNodeDefinitions')
|
||||
dialogStore.closeDialog()
|
||||
} catch (error) {
|
||||
console.error(error)
|
||||
importError.value = error.message
|
||||
} finally {
|
||||
importing.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const init = async () => {
|
||||
if (!modelStore.modelFolders.length) {
|
||||
await modelStore.loadModelFolders()
|
||||
}
|
||||
|
||||
modelFolders.value = modelStore.modelFolders.map((folder) => folder.directory)
|
||||
const type = await guessModelType(file)
|
||||
|
||||
if (!selectedType.value) {
|
||||
selectedType.value = type
|
||||
}
|
||||
}
|
||||
|
||||
init()
|
||||
</script>
|
||||
@@ -1405,5 +1405,11 @@
|
||||
"tooltip": "Execute to selected output nodes (Highlighted with orange border)",
|
||||
"disabledTooltip": "No output nodes selected"
|
||||
}
|
||||
},
|
||||
"importModelDialog": {
|
||||
"title": "Import Model",
|
||||
"type": "Type",
|
||||
"move": "Move",
|
||||
"copy": "Copy"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,7 @@ interface QueuePromptRequestBody {
|
||||
interface FrontendApiCalls {
|
||||
graphChanged: ComfyWorkflowJSON
|
||||
promptQueued: { number: number; batchCount: number }
|
||||
unhandledFileDrop: { file: File }
|
||||
graphCleared: never
|
||||
reconnecting: never
|
||||
reconnected: never
|
||||
@@ -313,20 +314,23 @@ export class ComfyApi extends EventTarget {
|
||||
* Provides type safety for the contravariance issue with EventTarget (last checked TS 5.6).
|
||||
* @param type The type of event to emit
|
||||
* @param detail The detail property used for a custom event ({@link CustomEventInit.detail})
|
||||
* @param init The event config used for a custom event ({@link CustomEventInit})
|
||||
*/
|
||||
dispatchCustomEvent<T extends SimpleApiEvents>(type: T): boolean
|
||||
dispatchCustomEvent<T extends ComplexApiEvents>(
|
||||
type: T,
|
||||
detail: ApiEventTypes[T] | null
|
||||
detail: ApiEventTypes[T] | null,
|
||||
init?: EventInit
|
||||
): boolean
|
||||
dispatchCustomEvent<T extends keyof ApiEventTypes>(
|
||||
type: T,
|
||||
detail?: ApiEventTypes[T]
|
||||
detail?: ApiEventTypes[T],
|
||||
init?: EventInit
|
||||
): boolean {
|
||||
const event =
|
||||
detail === undefined
|
||||
? new CustomEvent(type)
|
||||
: new CustomEvent(type, { detail })
|
||||
? new CustomEvent(type, { ...init })
|
||||
: new CustomEvent(type, { detail, ...init })
|
||||
return super.dispatchEvent(event)
|
||||
}
|
||||
|
||||
|
||||
@@ -1252,10 +1252,22 @@ export class ComfyApp {
|
||||
return !executionStore.lastNodeErrors
|
||||
}
|
||||
|
||||
showErrorOnFileLoad(file: File) {
|
||||
useToastStore().addAlert(
|
||||
t('toastMessages.fileLoadError', { fileName: file.name })
|
||||
onUnhandledFile(file: File) {
|
||||
// Fire custom event to allow other parts of the app to handle the file
|
||||
const unhandled = api.dispatchCustomEvent(
|
||||
'unhandledFileDrop',
|
||||
{ file },
|
||||
{
|
||||
cancelable: true
|
||||
}
|
||||
)
|
||||
|
||||
if (unhandled) {
|
||||
// Nothing handled the event, so show the error dialog
|
||||
useToastStore().addAlert(
|
||||
t('toastMessages.fileLoadError', { fileName: file.name })
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1291,7 +1303,7 @@ export class ComfyApp {
|
||||
this.graph.serialize() as unknown as ComfyWorkflowJSON
|
||||
)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (file.type === 'image/webp') {
|
||||
const pngInfo = await getWebpMetadata(file)
|
||||
@@ -1304,7 +1316,7 @@ export class ComfyApp {
|
||||
} else if (prompt) {
|
||||
this.loadApiJson(JSON.parse(prompt), fileName)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (file.type === 'audio/mpeg') {
|
||||
const { workflow, prompt } = await getMp3Metadata(file)
|
||||
@@ -1313,7 +1325,7 @@ export class ComfyApp {
|
||||
} else if (prompt) {
|
||||
this.loadApiJson(prompt, fileName)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (file.type === 'audio/ogg') {
|
||||
const { workflow, prompt } = await getOggMetadata(file)
|
||||
@@ -1322,7 +1334,7 @@ export class ComfyApp {
|
||||
} else if (prompt) {
|
||||
this.loadApiJson(prompt, fileName)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (file.type === 'audio/flac' || file.type === 'audio/x-flac') {
|
||||
const pngInfo = await getFlacMetadata(file)
|
||||
@@ -1334,7 +1346,7 @@ export class ComfyApp {
|
||||
} else if (prompt) {
|
||||
this.loadApiJson(JSON.parse(prompt), fileName)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (file.type === 'video/webm') {
|
||||
const webmInfo = await getFromWebmFile(file)
|
||||
@@ -1343,7 +1355,7 @@ export class ComfyApp {
|
||||
} else if (webmInfo.prompt) {
|
||||
this.loadApiJson(webmInfo.prompt, fileName)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (
|
||||
file.type === 'video/mp4' ||
|
||||
@@ -1366,7 +1378,7 @@ export class ComfyApp {
|
||||
} else if (svgInfo.prompt) {
|
||||
this.loadApiJson(svgInfo.prompt, fileName)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (
|
||||
file.type === 'model/gltf-binary' ||
|
||||
@@ -1378,7 +1390,7 @@ export class ComfyApp {
|
||||
} else if (gltfInfo.prompt) {
|
||||
this.loadApiJson(gltfInfo.prompt, fileName)
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else if (
|
||||
file.type === 'application/json' ||
|
||||
@@ -1409,7 +1421,7 @@ export class ComfyApp {
|
||||
const info = await getLatentMetadata(file)
|
||||
// TODO define schema to LatentMetadata
|
||||
// @ts-expect-error
|
||||
if (info.workflow) {
|
||||
if (info?.workflow) {
|
||||
await this.loadGraphData(
|
||||
// @ts-expect-error
|
||||
JSON.parse(info.workflow),
|
||||
@@ -1418,14 +1430,14 @@ export class ComfyApp {
|
||||
fileName
|
||||
)
|
||||
// @ts-expect-error
|
||||
} else if (info.prompt) {
|
||||
} else if (info?.prompt) {
|
||||
// @ts-expect-error
|
||||
this.loadApiJson(JSON.parse(info.prompt))
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file)
|
||||
this.onUnhandledFile(file)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -143,12 +143,17 @@ export function getLatentMetadata(file) {
|
||||
const dataView = new DataView(safetensorsData.buffer)
|
||||
let header_size = dataView.getUint32(0, true)
|
||||
let offset = 8
|
||||
let header = JSON.parse(
|
||||
new TextDecoder().decode(
|
||||
safetensorsData.slice(offset, offset + header_size)
|
||||
try {
|
||||
let header = JSON.parse(
|
||||
new TextDecoder().decode(
|
||||
safetensorsData.slice(offset, offset + header_size)
|
||||
)
|
||||
)
|
||||
)
|
||||
r(header.__metadata__)
|
||||
r(header.__metadata__)
|
||||
} catch (e) {
|
||||
// Invalid header
|
||||
r(undefined)
|
||||
}
|
||||
}
|
||||
|
||||
var slice = file.slice(0, 1024 * 1024 * 4)
|
||||
|
||||
@@ -2,6 +2,7 @@ import ApiNodesNewsContent from '@/components/dialog/content/ApiNodesNewsContent
|
||||
import ApiNodesSignInContent from '@/components/dialog/content/ApiNodesSignInContent.vue'
|
||||
import ConfirmationDialogContent from '@/components/dialog/content/ConfirmationDialogContent.vue'
|
||||
import ErrorDialogContent from '@/components/dialog/content/ErrorDialogContent.vue'
|
||||
import ImportModelDialogContent from '@/components/dialog/content/ImportModelDialogContent.vue'
|
||||
import IssueReportDialogContent from '@/components/dialog/content/IssueReportDialogContent.vue'
|
||||
import LoadWorkflowWarning from '@/components/dialog/content/LoadWorkflowWarning.vue'
|
||||
import ManagerProgressDialogContent from '@/components/dialog/content/ManagerProgressDialogContent.vue'
|
||||
@@ -406,6 +407,16 @@ export const useDialogService = () => {
|
||||
})
|
||||
}
|
||||
|
||||
function showImportModelDialog(
|
||||
props: InstanceType<typeof ImportModelDialogContent>['$props']
|
||||
) {
|
||||
dialogStore.showDialog({
|
||||
key: 'global-import-model',
|
||||
component: ImportModelDialogContent,
|
||||
props
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
showLoadWorkflowWarning,
|
||||
showMissingModelsWarning,
|
||||
@@ -422,6 +433,7 @@ export const useDialogService = () => {
|
||||
showTopUpCreditsDialog,
|
||||
showUpdatePasswordDialog,
|
||||
showApiNodesNewsDialog,
|
||||
showImportModelDialog,
|
||||
prompt,
|
||||
confirm
|
||||
}
|
||||
|
||||
93
src/utils/safetensorsUtil.ts
Normal file
93
src/utils/safetensorsUtil.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
export interface ModelSpec {
|
||||
'modelspec.sai_model_spec': string
|
||||
'modelspec.architecture': string
|
||||
'modelspec.title': string
|
||||
'modelspec.description': string
|
||||
}
|
||||
|
||||
const architectureToType: Record<string, string> = {
|
||||
'stable-diffusion-v1': 'checkpoints',
|
||||
'stable-diffusion-xl-v1-base': 'checkpoints',
|
||||
'Flux.1-schnell': 'checkpoints',
|
||||
'Flux.1-dev': 'checkpoints',
|
||||
'Flux.1-AE': 'vae'
|
||||
}
|
||||
|
||||
interface SafetensorsHeader<TMetadata = Record<string, string> | ModelSpec> {
|
||||
[k: string]: unknown
|
||||
__metadata__?: TMetadata
|
||||
}
|
||||
|
||||
export async function guessModelType(file: File): Promise<string | null> {
|
||||
const header = await getHeader(file)
|
||||
if (!header) return null
|
||||
|
||||
let suggestedType: string | null
|
||||
if (isModelSpec(header)) {
|
||||
suggestedType = guessFromModelSpec(header)
|
||||
}
|
||||
|
||||
suggestedType ??= guessFromHeaderKeys(header)
|
||||
|
||||
return suggestedType
|
||||
}
|
||||
|
||||
async function getHeader(file: File): Promise<SafetensorsHeader | null> {
|
||||
try {
|
||||
// 8 bytes: an unsigned little-endian 64-bit integer, containing the size of the header
|
||||
// Slice the first 8 bytes so we don't read the whole file
|
||||
const headerSizeBlob = file.slice(0, 8)
|
||||
const headerSizeView = new DataView(await headerSizeBlob.arrayBuffer())
|
||||
const headerSize = headerSizeView.getBigUint64(0, true)
|
||||
|
||||
if (
|
||||
headerSize < 0 ||
|
||||
headerSize > file.size ||
|
||||
headerSize > Number.MAX_SAFE_INTEGER
|
||||
) {
|
||||
// Invalid header, probably not a safetensors file
|
||||
console.log(`Invalid header size ${headerSize} for file '${file.name}'`)
|
||||
return null
|
||||
}
|
||||
|
||||
// N bytes: a JSON UTF-8 string representing the header.
|
||||
const header = file.slice(8, Number(headerSize) + 8)
|
||||
const content = await header.text()
|
||||
return JSON.parse(content)
|
||||
} catch (error) {
|
||||
// Error reading the file, probably not a safetensors file
|
||||
console.error(`Error reading safetensors header '${file.name}'`, error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function guessFromModelSpec(header: SafetensorsHeader<ModelSpec>) {
|
||||
const architecture = header.__metadata__?.['modelspec.architecture']
|
||||
if (!architecture) return null
|
||||
let suggestedType = architectureToType[architecture]
|
||||
if (!suggestedType) {
|
||||
if (architecture?.endsWith('/lora')) {
|
||||
suggestedType = 'loras'
|
||||
}
|
||||
}
|
||||
|
||||
return suggestedType
|
||||
}
|
||||
|
||||
function guessFromHeaderKeys(header: SafetensorsHeader) {
|
||||
let suggestedType: string | null = null
|
||||
const keys = Object.keys(header)
|
||||
if (keys.find((k) => k.startsWith('lora_unet_'))) {
|
||||
suggestedType = 'loras'
|
||||
} else if (keys.find((k) => k.startsWith('model.diffusion_model.'))) {
|
||||
suggestedType = 'checkpoints'
|
||||
}
|
||||
|
||||
return suggestedType
|
||||
}
|
||||
|
||||
function isModelSpec(
|
||||
header: SafetensorsHeader
|
||||
): header is SafetensorsHeader<ModelSpec> {
|
||||
return !!header.__metadata__?.['modelspec.sai_model_spec']
|
||||
}
|
||||
Reference in New Issue
Block a user