mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-04-29 02:32:18 +00:00
initial model store (#674)
* initial model store * refactor the 'modelstoreserviceimpl' to pinia * pepper in some reactive the inner ModelStore (per-folder) can't be pinia because its made of temporary instances, but it can be reactive * use refs in metadata class * remove 'reactive' * remove ref too * add simple unit tests for modelStore * make things worse via autoformatting * move mock impls to a function
This commit is contained in:
committed by
GitHub
parent
6c7fb5041d
commit
060e61f0db
@@ -343,6 +343,19 @@ class ComfyApi extends EventTarget {
|
|||||||
return await res.json()
|
return await res.json()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the metadata for a model
|
||||||
|
* @param {string} folder The folder containing the model
|
||||||
|
* @param {string} model The model to get metadata for
|
||||||
|
* @returns The metadata for the model
|
||||||
|
*/
|
||||||
|
async viewMetadata(folder: string, model: string) {
|
||||||
|
const res = await this.fetchApi(
|
||||||
|
`/view_metadata/${folder}?filename=${encodeURIComponent(model)}`
|
||||||
|
)
|
||||||
|
return await res.json()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tells the server to download a model from the specified URL to the specified directory and filename
|
* Tells the server to download a model from the specified URL to the specified directory and filename
|
||||||
* @param {string} url The URL to download the model from
|
* @param {string} url The URL to download the model from
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ import {
|
|||||||
} from '@/services/dialogService'
|
} from '@/services/dialogService'
|
||||||
import { useSettingStore } from '@/stores/settingStore'
|
import { useSettingStore } from '@/stores/settingStore'
|
||||||
import { useToastStore } from '@/stores/toastStore'
|
import { useToastStore } from '@/stores/toastStore'
|
||||||
|
import { ModelStore, useModelStore } from '@/stores/modelStore'
|
||||||
import type { ToastMessageOptions } from 'primevue/toast'
|
import type { ToastMessageOptions } from 'primevue/toast'
|
||||||
import { useWorkspaceStore } from '@/stores/workspaceStateStore'
|
import { useWorkspaceStore } from '@/stores/workspaceStateStore'
|
||||||
import { LGraphGroup } from '@comfyorg/litegraph'
|
import { LGraphGroup } from '@comfyorg/litegraph'
|
||||||
@@ -137,7 +138,6 @@ export class ComfyApp {
|
|||||||
bodyBottom: HTMLElement
|
bodyBottom: HTMLElement
|
||||||
canvasContainer: HTMLElement
|
canvasContainer: HTMLElement
|
||||||
menu: ComfyAppMenu
|
menu: ComfyAppMenu
|
||||||
modelsInFolderCache: Record<string, string[]>
|
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.vueAppReady = false
|
this.vueAppReady = false
|
||||||
@@ -152,7 +152,6 @@ export class ComfyApp {
|
|||||||
parent: document.body
|
parent: document.body
|
||||||
})
|
})
|
||||||
this.menu = new ComfyAppMenu(this)
|
this.menu = new ComfyAppMenu(this)
|
||||||
this.modelsInFolderCache = {}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List of extensions that are registered with the app
|
* List of extensions that are registered with the app
|
||||||
@@ -2260,12 +2259,14 @@ export class ComfyApp {
|
|||||||
useSettingStore().get('Comfy.Workflow.ShowMissingModelsWarning')
|
useSettingStore().get('Comfy.Workflow.ShowMissingModelsWarning')
|
||||||
) {
|
) {
|
||||||
for (let m of graphData.models) {
|
for (let m of graphData.models) {
|
||||||
const models_available = await this.getModelsInFolderCached(m.directory)
|
const models_available = await useModelStore().getModelsInFolderCached(
|
||||||
|
m.directory
|
||||||
|
)
|
||||||
if (models_available === null) {
|
if (models_available === null) {
|
||||||
// @ts-expect-error
|
// @ts-expect-error
|
||||||
m.directory_invalid = true
|
m.directory_invalid = true
|
||||||
missingModels.push(m)
|
missingModels.push(m)
|
||||||
} else if (!models_available.includes(m.name)) {
|
} else if (!(m.name in models_available.models)) {
|
||||||
missingModels.push(m)
|
missingModels.push(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2860,19 +2861,6 @@ export class ComfyApp {
|
|||||||
app.graph.arrange()
|
app.graph.arrange()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets the list of model names in a folder, using a temporary local cache
|
|
||||||
*/
|
|
||||||
async getModelsInFolderCached(folder: string): Promise<string[]> {
|
|
||||||
if (folder in this.modelsInFolderCache) {
|
|
||||||
return this.modelsInFolderCache[folder]
|
|
||||||
}
|
|
||||||
// TODO: needs a lock to avoid overlapping calls
|
|
||||||
const models = await api.getModels(folder)
|
|
||||||
this.modelsInFolderCache[folder] = models
|
|
||||||
return models
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Registers a Comfy web extension with the app
|
* Registers a Comfy web extension with the app
|
||||||
* @param {ComfyExtension} extension
|
* @param {ComfyExtension} extension
|
||||||
@@ -2901,9 +2889,10 @@ export class ComfyApp {
|
|||||||
summary: 'Update',
|
summary: 'Update',
|
||||||
detail: 'Update requested'
|
detail: 'Update requested'
|
||||||
}
|
}
|
||||||
if (this.vueAppReady) useToastStore().add(requestToastMessage)
|
if (this.vueAppReady) {
|
||||||
|
useToastStore().add(requestToastMessage)
|
||||||
this.modelsInFolderCache = {}
|
useModelStore().clearCache()
|
||||||
|
}
|
||||||
|
|
||||||
const defs = await api.getNodeDefs()
|
const defs = await api.getNodeDefs()
|
||||||
|
|
||||||
|
|||||||
136
src/stores/modelStore.ts
Normal file
136
src/stores/modelStore.ts
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
import { api } from '@/scripts/api'
|
||||||
|
import { defineStore } from 'pinia'
|
||||||
|
|
||||||
|
/** (Internal helper) finds a value in a metadata object from any of a list of keys. */
|
||||||
|
function _findInMetadata(metadata: any, ...keys: string[]): string | null {
|
||||||
|
for (const key of keys) {
|
||||||
|
if (key in metadata) {
|
||||||
|
return metadata[key]
|
||||||
|
}
|
||||||
|
for (const k in metadata) {
|
||||||
|
if (k.endsWith(key)) {
|
||||||
|
return metadata[k]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Defines and holds metadata for a model */
|
||||||
|
export class ComfyModelDef {
|
||||||
|
/** Proper filename of the model */
|
||||||
|
name: string = ''
|
||||||
|
/** Directory containing the model, eg 'checkpoints' */
|
||||||
|
directory: string = ''
|
||||||
|
/** Title / display name of the model, sometimes same as the name but not always */
|
||||||
|
title: string = ''
|
||||||
|
/** Metadata: architecture ID for the model, such as 'stable-diffusion-xl-v1-base' */
|
||||||
|
architecture_id: string = ''
|
||||||
|
/** Metadata: author of the model */
|
||||||
|
author: string = ''
|
||||||
|
/** Metadata: resolution of the model, eg '1024x1024' */
|
||||||
|
resolution: string = ''
|
||||||
|
/** Metadata: description of the model */
|
||||||
|
description: string = ''
|
||||||
|
/** Metadata: usage hint for the model */
|
||||||
|
usage_hint: string = ''
|
||||||
|
/** Metadata: trigger phrase for the model */
|
||||||
|
trigger_phrase: string = ''
|
||||||
|
/** Metadata: tags list for the model */
|
||||||
|
tags: string[] = []
|
||||||
|
/** Metadata: image for the model */
|
||||||
|
image: string = ''
|
||||||
|
/** Whether the model metadata has been loaded from the server, used for `load()` */
|
||||||
|
has_loaded_metadata: boolean = false
|
||||||
|
|
||||||
|
constructor(name: string, directory: string) {
|
||||||
|
this.name = name
|
||||||
|
this.title = name
|
||||||
|
this.directory = directory
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Loads the model metadata from the server, filling in this object if data is available */
|
||||||
|
async load(): Promise<void> {
|
||||||
|
if (this.has_loaded_metadata) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const metadata = await api.viewMetadata(this.directory, this.name)
|
||||||
|
if (!metadata) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
this.title =
|
||||||
|
_findInMetadata(
|
||||||
|
metadata,
|
||||||
|
'modelspec.title',
|
||||||
|
'title',
|
||||||
|
'display_name',
|
||||||
|
'name'
|
||||||
|
) || this.name
|
||||||
|
this.architecture_id =
|
||||||
|
_findInMetadata(metadata, 'modelspec.architecture', 'architecture') || ''
|
||||||
|
this.author = _findInMetadata(metadata, 'modelspec.author', 'author') || ''
|
||||||
|
this.description =
|
||||||
|
_findInMetadata(metadata, 'modelspec.description', 'description') || ''
|
||||||
|
this.resolution =
|
||||||
|
_findInMetadata(metadata, 'modelspec.resolution', 'resolution') || ''
|
||||||
|
this.usage_hint =
|
||||||
|
_findInMetadata(metadata, 'modelspec.usage_hint', 'usage_hint') || ''
|
||||||
|
this.trigger_phrase =
|
||||||
|
_findInMetadata(metadata, 'modelspec.trigger_phrase', 'trigger_phrase') ||
|
||||||
|
''
|
||||||
|
this.image =
|
||||||
|
_findInMetadata(
|
||||||
|
metadata,
|
||||||
|
'modelspec.thumbnail',
|
||||||
|
'thumbnail',
|
||||||
|
'image',
|
||||||
|
'icon'
|
||||||
|
) || ''
|
||||||
|
const tagsCommaSeparated =
|
||||||
|
_findInMetadata(metadata, 'modelspec.tags', 'tags') || ''
|
||||||
|
this.tags = tagsCommaSeparated.split(',').map((tag) => tag.trim())
|
||||||
|
this.has_loaded_metadata = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Model store for a folder */
|
||||||
|
export class ModelStore {
|
||||||
|
models: Record<string, ComfyModelDef> = {}
|
||||||
|
|
||||||
|
constructor(directory: string, models: string[]) {
|
||||||
|
for (const model of models) {
|
||||||
|
this.models[model] = new ComfyModelDef(model, directory)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async loadModelMetadata(modelName: string) {
|
||||||
|
if (this.models[modelName]) {
|
||||||
|
await this.models[modelName].load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Model store handler, wraps individual per-folder model stores */
|
||||||
|
export const useModelStore = defineStore('modelStore', {
|
||||||
|
state: () => ({
|
||||||
|
modelStoreMap: {} as Record<string, ModelStore>
|
||||||
|
}),
|
||||||
|
actions: {
|
||||||
|
async getModelsInFolderCached(folder: string): Promise<ModelStore> {
|
||||||
|
if (folder in this.modelStoreMap) {
|
||||||
|
return this.modelStoreMap[folder]
|
||||||
|
}
|
||||||
|
// TODO: needs a lock to avoid overlapping calls
|
||||||
|
const models = await api.getModels(folder)
|
||||||
|
if (!models) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
const store = new ModelStore(folder, models)
|
||||||
|
this.modelStoreMap[folder] = store
|
||||||
|
return store
|
||||||
|
},
|
||||||
|
clearCache() {
|
||||||
|
this.modelStoreMap = {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
85
tests-ui/tests/store/modelStore.test.ts
Normal file
85
tests-ui/tests/store/modelStore.test.ts
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import { setActivePinia, createPinia } from 'pinia'
|
||||||
|
import { useModelStore } from '@/stores/modelStore'
|
||||||
|
import { api } from '@/scripts/api'
|
||||||
|
|
||||||
|
// Mock the api
|
||||||
|
jest.mock('@/scripts/api', () => ({
|
||||||
|
api: {
|
||||||
|
getModels: jest.fn(),
|
||||||
|
viewMetadata: jest.fn()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
function enableMocks() {
|
||||||
|
;(api.getModels as jest.Mock).mockResolvedValue([
|
||||||
|
'sdxl.safetensors',
|
||||||
|
'sdv15.safetensors',
|
||||||
|
'noinfo.safetensors'
|
||||||
|
])
|
||||||
|
;(api.viewMetadata as jest.Mock).mockImplementation((_, model) => {
|
||||||
|
if (model === 'noinfo.safetensors') {
|
||||||
|
return Promise.resolve({})
|
||||||
|
}
|
||||||
|
return Promise.resolve({
|
||||||
|
'modelspec.title': `Title of ${model}`,
|
||||||
|
display_name: 'Should not show',
|
||||||
|
'modelspec.architecture': 'stable-diffusion-xl-base-v1',
|
||||||
|
'modelspec.author': `Author of ${model}`,
|
||||||
|
'modelspec.description': `Description of ${model}`,
|
||||||
|
'modelspec.resolution': '1024x1024',
|
||||||
|
trigger_phrase: `Trigger phrase of ${model}`,
|
||||||
|
usage_hint: `Usage hint of ${model}`,
|
||||||
|
tags: `tags,for,${model}`
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('useModelStore', () => {
|
||||||
|
let store: ReturnType<typeof useModelStore>
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
setActivePinia(createPinia())
|
||||||
|
store = useModelStore()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should load models', async () => {
|
||||||
|
enableMocks()
|
||||||
|
const folderStore = await store.getModelsInFolderCached('checkpoints')
|
||||||
|
expect(folderStore).not.toBeNull()
|
||||||
|
expect(Object.keys(folderStore.models)).toHaveLength(3)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should load model metadata', async () => {
|
||||||
|
enableMocks()
|
||||||
|
const folderStore = await store.getModelsInFolderCached('checkpoints')
|
||||||
|
const model = folderStore.models['sdxl.safetensors']
|
||||||
|
await model.load()
|
||||||
|
expect(model.title).toBe('Title of sdxl.safetensors')
|
||||||
|
expect(model.architecture_id).toBe('stable-diffusion-xl-base-v1')
|
||||||
|
expect(model.author).toBe('Author of sdxl.safetensors')
|
||||||
|
expect(model.description).toBe('Description of sdxl.safetensors')
|
||||||
|
expect(model.resolution).toBe('1024x1024')
|
||||||
|
expect(model.trigger_phrase).toBe('Trigger phrase of sdxl.safetensors')
|
||||||
|
expect(model.usage_hint).toBe('Usage hint of sdxl.safetensors')
|
||||||
|
expect(model.tags).toHaveLength(3)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle no metadata', async () => {
|
||||||
|
enableMocks()
|
||||||
|
const folderStore = await store.getModelsInFolderCached('checkpoints')
|
||||||
|
const model = folderStore.models['noinfo.safetensors']
|
||||||
|
await model.load()
|
||||||
|
expect(model.title).toBe('noinfo.safetensors')
|
||||||
|
expect(model.architecture_id).toBe('')
|
||||||
|
expect(model.author).toBe('')
|
||||||
|
expect(model.description).toBe('')
|
||||||
|
expect(model.resolution).toBe('')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should cache model information', async () => {
|
||||||
|
enableMocks()
|
||||||
|
const folderStore1 = await store.getModelsInFolderCached('checkpoints')
|
||||||
|
const folderStore2 = await store.getModelsInFolderCached('checkpoints')
|
||||||
|
expect(api.getModels).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user