Compare commits

...

3 Commits

Author SHA1 Message Date
pythongosssss
0d61221ad3 Add tests 2025-05-10 20:56:46 +01:00
github-actions
fec5dbcf70 Update locales [skip ci] 2025-05-10 20:04:19 +01:00
pythongosssss
c72ba664ee Model file import support for desktop 2025-05-10 20:04:19 +01:00
17 changed files with 643 additions and 25 deletions

View File

@@ -0,0 +1,124 @@
import { test as base } from '@playwright/test'
type ElectronFixtureOptions = {
registerDefaults?: {
downloadManager?: boolean
}
}
type MockFunction = {
calls: unknown[][]
called: () => Promise<void>
handler?: (args: unknown[]) => unknown
}
export type MockElectronAPI = {
setup: (method: string, handler: (args: unknown[]) => unknown) => MockFunction
}
export const electronFixture = base.extend<{
electronAPI: MockElectronAPI
electronOptions: ElectronFixtureOptions
}>({
electronOptions: [
{
registerDefaults: {
downloadManager: true
}
},
{ option: true }
],
electronAPI: [
async ({ page, electronOptions }, use) => {
const mocks = new Map<string, MockFunction>()
await page.exposeFunction(
'__handleMockCall',
async (method: string, args: unknown[]) => {
const mock = mocks.get(method)
if (electronOptions.registerDefaults?.downloadManager) {
if (method === 'DownloadManager.getAllDownloads') {
return []
}
}
if (!mock) return null
mock.calls.push(args)
return mock.handler ? mock.handler(args) : null
}
)
const createMockFunction = (
method: string,
handler: (args: unknown[]) => unknown
): MockFunction => {
let resolveNextCall: (() => void) | null = null
const mockFn: MockFunction = {
calls: [],
async called() {
if (this.calls.length > 0) return
return new Promise<void>((resolve) => {
resolveNextCall = resolve
})
},
handler: (args: unknown[]) => {
const result = handler(args)
resolveNextCall?.()
resolveNextCall = null
return result
}
}
mocks.set(method, mockFn)
// Add the method to the window.electronAPI object
page.evaluate((methodName) => {
const w = window as typeof window & {
electronAPI: Record<string, any>
}
w.electronAPI[methodName] = async (...args: unknown[]) => {
return window['__handleMockCall'](methodName, args)
}
}, method)
return mockFn
}
const testAPI: MockElectronAPI = {
setup(method, handler) {
console.log('adding handler for', method)
return createMockFunction(method, handler)
}
}
await page.addInitScript(async () => {
const getProxy = (...path: string[]) => {
return new Proxy(() => {}, {
// Handle the proxy itself being called as a function
apply: async (target, thisArg, argArray) => {
return window['__handleMockCall'](path.join('.'), argArray)
},
// Handle property access
get: (target, prop: string) => {
return getProxy(...path, prop)
}
})
}
const w = window as typeof window & {
electronAPI: any
}
w.electronAPI = getProxy()
console.log('registered electron api')
})
await use(testAPI)
},
{ auto: true }
]
})

View File

@@ -0,0 +1,172 @@
import { expect, mergeTests } from '@playwright/test'
import { ComfyPage, comfyPageFixture } from '../fixtures/ComfyPage'
import { MockElectronAPI, electronFixture } from './fixtures/electron'
const test = mergeTests(comfyPageFixture, electronFixture)
comfyPageFixture.describe('Import Model (web)', () => {
comfyPageFixture(
'Import dialog does not show when electron api is not available',
async ({ comfyPage }) => {
await comfyPage.dragAndDropExternalResource({
fileName: 'test.bin',
buffer: Buffer.from('')
})
// Normal unable to find workflow in file error
await expect(
comfyPage.page.locator('.p-toast-message.p-toast-message-warn')
).toHaveCount(1)
}
)
})
test.describe('Import Model (electron)', () => {
const dropFile = async (
comfyPage: ComfyPage,
electronAPI: MockElectronAPI,
fileName: string,
metadata: string
) => {
const getFilePathMock = electronAPI.setup('getFilePath', () =>
Promise.resolve('some/file/path/' + fileName)
)
let buffer: Buffer | undefined
if (metadata) {
const contentBuffer = Buffer.from(metadata, 'utf-8')
const headerSizeBuffer = Buffer.alloc(8)
headerSizeBuffer.writeBigUInt64LE(BigInt(contentBuffer.length))
buffer = Buffer.concat([headerSizeBuffer, contentBuffer])
}
await comfyPage.dragAndDropExternalResource({
fileName,
buffer
})
await getFilePathMock.called()
await expect(
comfyPage.page.locator('.p-toast-message.p-toast-message-warn')
).toHaveCount(0)
await expect(comfyPage.importModelDialog.rootEl).toBeVisible()
}
test('Can show import file dialog by dropping file onto the app', async ({
comfyPage,
electronAPI
}) => {
await dropFile(comfyPage, electronAPI, 'test.bin', '{}')
})
test('Can autodetect checkpoint model type from modelspec', async ({
comfyPage,
electronAPI
}) => {
await dropFile(
comfyPage,
electronAPI,
'file.safetensors',
JSON.stringify({
__metadata__: {
'modelspec.sai_model_spec': 'test',
'modelspec.architecture': 'stable-diffusion-v1'
}
})
)
await expect(comfyPage.importModelDialog.modelTypeInput).toHaveValue(
'checkpoints'
)
})
test('Can autodetect lora model type from modelspec', async ({
comfyPage,
electronAPI
}) => {
await dropFile(
comfyPage,
electronAPI,
'file.safetensors',
JSON.stringify({
__metadata__: {
'modelspec.sai_model_spec': 'test',
'modelspec.architecture': 'Flux.1-AE/lora'
}
})
)
await expect(comfyPage.importModelDialog.modelTypeInput).toHaveValue(
'loras'
)
})
test('Can autodetect checkpoint model type from header keys', async ({
comfyPage,
electronAPI
}) => {
await dropFile(
comfyPage,
electronAPI,
'file.safetensors',
JSON.stringify({
'model.diffusion_model.input_blocks.0.0.bias': {}
})
)
await expect(comfyPage.importModelDialog.modelTypeInput).toHaveValue(
'checkpoints'
)
})
test('Can autodetect lora model type from header keys', async ({
comfyPage,
electronAPI
}) => {
await dropFile(
comfyPage,
electronAPI,
'file.safetensors',
JSON.stringify({
'lora_unet_down_blocks_0_attentions_0_proj_in.alpha': {}
})
)
await expect(comfyPage.importModelDialog.modelTypeInput).toHaveValue(
'loras'
)
})
test('Can import file', async ({ comfyPage, electronAPI }) => {
await dropFile(
comfyPage,
electronAPI,
'checkpoint_modelspec.safetensors',
'{}'
)
const importModelMock = electronAPI.setup(
'importModel',
() => new Promise((resolve) => setTimeout(resolve, 100))
)
// Model type is required so select one
await expect(comfyPage.importModelDialog.importButton).toBeDisabled()
await comfyPage.importModelDialog.modelTypeInput.fill('checkpoints')
await expect(comfyPage.importModelDialog.importButton).toBeEnabled()
// Click import, ensure API is called
await comfyPage.importModelDialog.importButton.click()
await importModelMock.called()
// Toast should be shown and dialog closes
await expect(
comfyPage.page.locator('.p-toast-message.p-toast-message-success')
).toHaveCount(1)
await expect(comfyPage.importModelDialog.rootEl).toBeHidden()
})
})

View File

@@ -13,6 +13,7 @@ import { ComfyActionbar } from '../helpers/actionbar'
import { ComfyTemplates } from '../helpers/templates'
import { ComfyMouse } from './ComfyMouse'
import { ComfyNodeSearchBox } from './components/ComfyNodeSearchBox'
import { ImportModelDialog } from './components/ImportModelDialog'
import { SettingDialog } from './components/SettingDialog'
import {
NodeLibrarySidebarTab,
@@ -140,6 +141,7 @@ export class ComfyPage {
public readonly templates: ComfyTemplates
public readonly settingDialog: SettingDialog
public readonly confirmDialog: ConfirmDialog
public readonly importModelDialog: ImportModelDialog
/** Worker index to test user ID */
public readonly userIds: string[] = []
@@ -165,6 +167,7 @@ export class ComfyPage {
this.templates = new ComfyTemplates(page)
this.settingDialog = new SettingDialog(page)
this.confirmDialog = new ConfirmDialog(page)
this.importModelDialog = new ImportModelDialog(page)
}
convertLeafToContent(structure: FolderStructure): FolderStructure {
@@ -469,6 +472,7 @@ export class ComfyPage {
fileName?: string
url?: string
dropPosition?: Position
buffer?: Buffer
} = {}
) {
const { dropPosition = { x: 100, y: 100 }, fileName, url } = options
@@ -487,7 +491,7 @@ export class ComfyPage {
// Dropping a file from the filesystem
if (fileName) {
const filePath = this.assetPath(fileName)
const buffer = fs.readFileSync(filePath)
const buffer = options.buffer ?? fs.readFileSync(filePath)
const getFileType = (fileName: string) => {
if (fileName.endsWith('.png')) return 'image/png'

View File

@@ -0,0 +1,17 @@
import { Page } from '@playwright/test'
export class ImportModelDialog {
constructor(public readonly page: Page) {}
get rootEl() {
return this.page.locator('div[aria-labelledby="global-import-model"]')
}
get modelTypeInput() {
return this.rootEl.locator('#model-type')
}
get importButton() {
return this.rootEl.getByLabel('Import')
}
}

View File

@@ -16,8 +16,10 @@ import { computed, onMounted } from 'vue'
import GlobalDialog from '@/components/dialog/GlobalDialog.vue'
import config from '@/config'
import { api } from '@/scripts/api'
import { useWorkspaceStore } from '@/stores/workspaceStore'
import { useDialogService } from './services/dialogService'
import { electronAPI, isElectron } from './utils/envUtil'
const workspaceStore = useWorkspaceStore()
@@ -46,6 +48,20 @@ onMounted(() => {
if (isElectron()) {
document.addEventListener('contextmenu', showContextMenu)
// Handle file drops to import models via electron
api.addEventListener('unhandledFileDrop', async (e) => {
e.preventDefault() // Prevent unable to find workflow in file error
const filePath = await electronAPI()['getFilePath'](e.detail.file)
if (filePath) {
useDialogService().showImportModelDialog({
path: filePath,
file: e.detail.file
})
}
})
}
})
</script>

View File

@@ -0,0 +1,123 @@
<template>
<div class="px-4 py-2 h-full gap-2 flex flex-col">
<h2 class="text-4xl font-normal my-0">
{{ t('importModelDialog.title') }}
</h2>
<span class="text-muted">{{ path }}</span>
<div class="flex flex-col gap-2 mt-4">
<IftaLabel>
<Select
v-model="selectedType"
:options="modelFolders"
editable
filter
labelId="model-type"
:disabled="importing"
/>
<label for="model-type">Type</label>
</IftaLabel>
</div>
<Message severity="error" v-if="importError">{{ importError }}</Message>
</div>
<footer>
<div class="flex justify-between gap-2 p-4">
<SelectButton
v-model="selectedImportMode"
optionLabel="label"
optionValue="value"
:options="importModes"
:disabled="importing"
/>
<div class="flex gap-2">
<Button
type="button"
label="Cancel"
severity="secondary"
@click="dialogStore.closeDialog()"
:disabled="importing"
></Button>
<Button
type="button"
label="Import"
@click="importModel()"
:icon="importIcon"
:loading="importing"
:disabled="!selectedType"
></Button>
</div>
</div>
</footer>
</template>
<script setup lang="ts">
import Button from 'primevue/button'
import IftaLabel from 'primevue/iftalabel'
import Message from 'primevue/message'
import Select from 'primevue/select'
import SelectButton from 'primevue/selectbutton'
import { computed, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { useCommandStore } from '@/stores/commandStore'
import { useDialogStore } from '@/stores/dialogStore'
import { useModelStore } from '@/stores/modelStore'
import { electronAPI } from '@/utils/envUtil'
import { guessModelType } from '@/utils/safetensorsUtil'
const { t } = useI18n()
const dialogStore = useDialogStore()
const { path, file } = defineProps<{
path: string
file: File
}>()
const importModes = ref([
{ label: t('importModelDialog.move'), value: 'move' },
{ label: t('importModelDialog.copy'), value: 'copy' }
])
const modelStore = useModelStore()
const modelFolders = ref<string[]>()
const selectedType = ref<string>()
const selectedImportMode = ref<string>('move')
const importing = ref<boolean>(false)
const importError = ref<string>()
const importIcon = computed(() => {
return selectedImportMode.value === 'move'
? 'pi pi-file-import'
: 'pi pi-copy'
})
const importModel = async () => {
importing.value = true
try {
await electronAPI()?.['importModel'](
file,
selectedType.value,
selectedImportMode.value
)
await useCommandStore().execute('Comfy.RefreshNodeDefinitions')
dialogStore.closeDialog()
} catch (error) {
console.error(error)
importError.value = error.message
} finally {
importing.value = false
}
}
const init = async () => {
if (!modelStore.modelFolders.length) {
await modelStore.loadModelFolders()
}
modelFolders.value = modelStore.modelFolders.map((folder) => folder.directory)
const type = await guessModelType(file)
if (!selectedType.value) {
selectedType.value = type
}
}
init()
</script>

View File

@@ -1405,5 +1405,11 @@
"tooltip": "Execute to selected output nodes (Highlighted with orange border)",
"disabledTooltip": "No output nodes selected"
}
},
"importModelDialog": {
"title": "Import Model",
"type": "Type",
"move": "Move",
"copy": "Copy"
}
}

View File

@@ -396,6 +396,12 @@
"inbox": "Boîte de réception",
"star": "Étoile"
},
"importModelDialog": {
"copy": "Copier",
"move": "Déplacer",
"title": "Importer le modèle",
"type": "Type"
},
"install": {
"appDataLocationTooltip": "Répertoire des données de l'application ComfyUI. Stocke :\n- Logs\n- Configurations du serveur",
"appPathLocationTooltip": "Répertoire des ressources de l'application ComfyUI. Stocke le code et les ressources de ComfyUI",

View File

@@ -396,6 +396,12 @@
"inbox": "受信トレイ",
"star": "星"
},
"importModelDialog": {
"copy": "コピー",
"move": "移動",
"title": "モデルをインポート",
"type": "タイプ"
},
"install": {
"appDataLocationTooltip": "ComfyUIのアプリデータディレクトリ。保存内容:\n- ログ\n- サーバー設定",
"appPathLocationTooltip": "ComfyUIのアプリ資産ディレクトリ。ComfyUIのコードとアセットを保存します",

View File

@@ -396,6 +396,12 @@
"inbox": "받은 편지함",
"star": "별"
},
"importModelDialog": {
"copy": "복사",
"move": "이동",
"title": "모델 가져오기",
"type": "유형"
},
"install": {
"appDataLocationTooltip": "ComfyUI의 앱 데이터 디렉토리. 저장소:\n- 로그\n- 서버 구성",
"appPathLocationTooltip": "ComfyUI의 앱 에셋 디렉토리. ComfyUI 코드 및 에셋을 저장합니다.",

View File

@@ -396,6 +396,12 @@
"inbox": "Входящие",
"star": "Звезда"
},
"importModelDialog": {
"copy": "Копировать",
"move": "Переместить",
"title": "Импорт модели",
"type": "Тип"
},
"install": {
"appDataLocationTooltip": "Директория данных приложения ComfyUI. Хранит:\n- Логи\n- Конфигурации сервера",
"appPathLocationTooltip": "Директория активов приложения ComfyUI. Хранит код и активы ComfyUI",

View File

@@ -396,6 +396,12 @@
"inbox": "收件箱",
"star": "星星"
},
"importModelDialog": {
"copy": "复制",
"move": "移动",
"title": "导入模型",
"type": "类型"
},
"install": {
"appDataLocationTooltip": "ComfyUI 的应用数据目录。存储:\n- 日志\n- 服务器配置",
"appPathLocationTooltip": "ComfyUI 的应用资产目录。存储 ComfyUI 代码和资产",

View File

@@ -82,6 +82,7 @@ interface QueuePromptRequestBody {
interface FrontendApiCalls {
graphChanged: ComfyWorkflowJSON
promptQueued: { number: number; batchCount: number }
unhandledFileDrop: { file: File }
graphCleared: never
reconnecting: never
reconnected: never
@@ -313,20 +314,23 @@ export class ComfyApi extends EventTarget {
* Provides type safety for the contravariance issue with EventTarget (last checked TS 5.6).
* @param type The type of event to emit
* @param detail The detail property used for a custom event ({@link CustomEventInit.detail})
* @param init The event config used for a custom event ({@link CustomEventInit})
*/
dispatchCustomEvent<T extends SimpleApiEvents>(type: T): boolean
dispatchCustomEvent<T extends ComplexApiEvents>(
type: T,
detail: ApiEventTypes[T] | null
detail: ApiEventTypes[T] | null,
init?: EventInit
): boolean
dispatchCustomEvent<T extends keyof ApiEventTypes>(
type: T,
detail?: ApiEventTypes[T]
detail?: ApiEventTypes[T],
init?: EventInit
): boolean {
const event =
detail === undefined
? new CustomEvent(type)
: new CustomEvent(type, { detail })
? new CustomEvent(type, { ...init })
: new CustomEvent(type, { detail, ...init })
return super.dispatchEvent(event)
}

View File

@@ -1252,10 +1252,22 @@ export class ComfyApp {
return !executionStore.lastNodeErrors
}
showErrorOnFileLoad(file: File) {
useToastStore().addAlert(
t('toastMessages.fileLoadError', { fileName: file.name })
onUnhandledFile(file: File) {
// Fire custom event to allow other parts of the app to handle the file
const unhandled = api.dispatchCustomEvent(
'unhandledFileDrop',
{ file },
{
cancelable: true
}
)
if (unhandled) {
// Nothing handled the event, so show the error dialog
useToastStore().addAlert(
t('toastMessages.fileLoadError', { fileName: file.name })
)
}
}
/**
@@ -1291,7 +1303,7 @@ export class ComfyApp {
this.graph.serialize() as unknown as ComfyWorkflowJSON
)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'image/webp') {
const pngInfo = await getWebpMetadata(file)
@@ -1304,7 +1316,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt), fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'audio/mpeg') {
const { workflow, prompt } = await getMp3Metadata(file)
@@ -1313,7 +1325,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'audio/ogg') {
const { workflow, prompt } = await getOggMetadata(file)
@@ -1322,7 +1334,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'audio/flac' || file.type === 'audio/x-flac') {
const pngInfo = await getFlacMetadata(file)
@@ -1334,7 +1346,7 @@ export class ComfyApp {
} else if (prompt) {
this.loadApiJson(JSON.parse(prompt), fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (file.type === 'video/webm') {
const webmInfo = await getFromWebmFile(file)
@@ -1343,7 +1355,7 @@ export class ComfyApp {
} else if (webmInfo.prompt) {
this.loadApiJson(webmInfo.prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (
file.type === 'video/mp4' ||
@@ -1366,7 +1378,7 @@ export class ComfyApp {
} else if (svgInfo.prompt) {
this.loadApiJson(svgInfo.prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (
file.type === 'model/gltf-binary' ||
@@ -1378,7 +1390,7 @@ export class ComfyApp {
} else if (gltfInfo.prompt) {
this.loadApiJson(gltfInfo.prompt, fileName)
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else if (
file.type === 'application/json' ||
@@ -1409,7 +1421,7 @@ export class ComfyApp {
const info = await getLatentMetadata(file)
// TODO define schema to LatentMetadata
// @ts-expect-error
if (info.workflow) {
if (info?.workflow) {
await this.loadGraphData(
// @ts-expect-error
JSON.parse(info.workflow),
@@ -1418,14 +1430,14 @@ export class ComfyApp {
fileName
)
// @ts-expect-error
} else if (info.prompt) {
} else if (info?.prompt) {
// @ts-expect-error
this.loadApiJson(JSON.parse(info.prompt))
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
} else {
this.showErrorOnFileLoad(file)
this.onUnhandledFile(file)
}
}

View File

@@ -143,12 +143,17 @@ export function getLatentMetadata(file) {
const dataView = new DataView(safetensorsData.buffer)
let header_size = dataView.getUint32(0, true)
let offset = 8
let header = JSON.parse(
new TextDecoder().decode(
safetensorsData.slice(offset, offset + header_size)
try {
let header = JSON.parse(
new TextDecoder().decode(
safetensorsData.slice(offset, offset + header_size)
)
)
)
r(header.__metadata__)
r(header.__metadata__)
} catch (e) {
// Invalid header
r(undefined)
}
}
var slice = file.slice(0, 1024 * 1024 * 4)

View File

@@ -2,6 +2,7 @@ import ApiNodesNewsContent from '@/components/dialog/content/ApiNodesNewsContent
import ApiNodesSignInContent from '@/components/dialog/content/ApiNodesSignInContent.vue'
import ConfirmationDialogContent from '@/components/dialog/content/ConfirmationDialogContent.vue'
import ErrorDialogContent from '@/components/dialog/content/ErrorDialogContent.vue'
import ImportModelDialogContent from '@/components/dialog/content/ImportModelDialogContent.vue'
import IssueReportDialogContent from '@/components/dialog/content/IssueReportDialogContent.vue'
import LoadWorkflowWarning from '@/components/dialog/content/LoadWorkflowWarning.vue'
import ManagerProgressDialogContent from '@/components/dialog/content/ManagerProgressDialogContent.vue'
@@ -406,6 +407,16 @@ export const useDialogService = () => {
})
}
function showImportModelDialog(
props: InstanceType<typeof ImportModelDialogContent>['$props']
) {
dialogStore.showDialog({
key: 'global-import-model',
component: ImportModelDialogContent,
props
})
}
return {
showLoadWorkflowWarning,
showMissingModelsWarning,
@@ -422,6 +433,7 @@ export const useDialogService = () => {
showTopUpCreditsDialog,
showUpdatePasswordDialog,
showApiNodesNewsDialog,
showImportModelDialog,
prompt,
confirm
}

View File

@@ -0,0 +1,93 @@
export interface ModelSpec {
'modelspec.sai_model_spec': string
'modelspec.architecture': string
'modelspec.title': string
'modelspec.description': string
}
const architectureToType: Record<string, string> = {
'stable-diffusion-v1': 'checkpoints',
'stable-diffusion-xl-v1-base': 'checkpoints',
'Flux.1-schnell': 'checkpoints',
'Flux.1-dev': 'checkpoints',
'Flux.1-AE': 'vae'
}
interface SafetensorsHeader<TMetadata = Record<string, string> | ModelSpec> {
[k: string]: unknown
__metadata__?: TMetadata
}
export async function guessModelType(file: File): Promise<string | null> {
const header = await getHeader(file)
if (!header) return null
let suggestedType: string | null
if (isModelSpec(header)) {
suggestedType = guessFromModelSpec(header)
}
suggestedType ??= guessFromHeaderKeys(header)
return suggestedType
}
async function getHeader(file: File): Promise<SafetensorsHeader | null> {
try {
// 8 bytes: an unsigned little-endian 64-bit integer, containing the size of the header
// Slice the first 8 bytes so we don't read the whole file
const headerSizeBlob = file.slice(0, 8)
const headerSizeView = new DataView(await headerSizeBlob.arrayBuffer())
const headerSize = headerSizeView.getBigUint64(0, true)
if (
headerSize < 0 ||
headerSize > file.size ||
headerSize > Number.MAX_SAFE_INTEGER
) {
// Invalid header, probably not a safetensors file
console.log(`Invalid header size ${headerSize} for file '${file.name}'`)
return null
}
// N bytes: a JSON UTF-8 string representing the header.
const header = file.slice(8, Number(headerSize) + 8)
const content = await header.text()
return JSON.parse(content)
} catch (error) {
// Error reading the file, probably not a safetensors file
console.error(`Error reading safetensors header '${file.name}'`, error)
return null
}
}
function guessFromModelSpec(header: SafetensorsHeader<ModelSpec>) {
const architecture = header.__metadata__?.['modelspec.architecture']
if (!architecture) return null
let suggestedType = architectureToType[architecture]
if (!suggestedType) {
if (architecture?.endsWith('/lora')) {
suggestedType = 'loras'
}
}
return suggestedType
}
function guessFromHeaderKeys(header: SafetensorsHeader) {
let suggestedType: string | null = null
const keys = Object.keys(header)
if (keys.find((k) => k.startsWith('lora_unet_'))) {
suggestedType = 'loras'
} else if (keys.find((k) => k.startsWith('model.diffusion_model.'))) {
suggestedType = 'checkpoints'
}
return suggestedType
}
function isModelSpec(
header: SafetensorsHeader
): header is SafetensorsHeader<ModelSpec> {
return !!header.__metadata__?.['modelspec.sai_model_spec']
}