diff --git a/src/scripts/api.ts b/src/scripts/api.ts index 868a15042..951494a1b 100644 --- a/src/scripts/api.ts +++ b/src/scripts/api.ts @@ -343,6 +343,19 @@ class ComfyApi extends EventTarget { return await res.json() } + /** + * Gets the metadata for a model + * @param {string} folder The folder containing the model + * @param {string} model The model to get metadata for + * @returns The metadata for the model + */ + async viewMetadata(folder: string, model: string) { + const res = await this.fetchApi( + `/view_metadata/${folder}?filename=${encodeURIComponent(model)}` + ) + return await res.json() + } + /** * Tells the server to download a model from the specified URL to the specified directory and filename * @param {string} url The URL to download the model from diff --git a/src/scripts/app.ts b/src/scripts/app.ts index edf579e3b..68e9e22c7 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -51,6 +51,7 @@ import { } from '@/services/dialogService' import { useSettingStore } from '@/stores/settingStore' import { useToastStore } from '@/stores/toastStore' +import { ModelStore, useModelStore } from '@/stores/modelStore' import type { ToastMessageOptions } from 'primevue/toast' import { useWorkspaceStore } from '@/stores/workspaceStateStore' import { LGraphGroup } from '@comfyorg/litegraph' @@ -137,7 +138,6 @@ export class ComfyApp { bodyBottom: HTMLElement canvasContainer: HTMLElement menu: ComfyAppMenu - modelsInFolderCache: Record constructor() { this.vueAppReady = false @@ -152,7 +152,6 @@ export class ComfyApp { parent: document.body }) this.menu = new ComfyAppMenu(this) - this.modelsInFolderCache = {} /** * List of extensions that are registered with the app @@ -2260,12 +2259,14 @@ export class ComfyApp { useSettingStore().get('Comfy.Workflow.ShowMissingModelsWarning') ) { for (let m of graphData.models) { - const models_available = await this.getModelsInFolderCached(m.directory) + const models_available = await useModelStore().getModelsInFolderCached( + m.directory + ) if (models_available === null) { // @ts-expect-error m.directory_invalid = true missingModels.push(m) - } else if (!models_available.includes(m.name)) { + } else if (!(m.name in models_available.models)) { missingModels.push(m) } } @@ -2860,19 +2861,6 @@ export class ComfyApp { app.graph.arrange() } - /** - * Gets the list of model names in a folder, using a temporary local cache - */ - async getModelsInFolderCached(folder: string): Promise { - if (folder in this.modelsInFolderCache) { - return this.modelsInFolderCache[folder] - } - // TODO: needs a lock to avoid overlapping calls - const models = await api.getModels(folder) - this.modelsInFolderCache[folder] = models - return models - } - /** * Registers a Comfy web extension with the app * @param {ComfyExtension} extension @@ -2901,9 +2889,10 @@ export class ComfyApp { summary: 'Update', detail: 'Update requested' } - if (this.vueAppReady) useToastStore().add(requestToastMessage) - - this.modelsInFolderCache = {} + if (this.vueAppReady) { + useToastStore().add(requestToastMessage) + useModelStore().clearCache() + } const defs = await api.getNodeDefs() diff --git a/src/stores/modelStore.ts b/src/stores/modelStore.ts new file mode 100644 index 000000000..490bad010 --- /dev/null +++ b/src/stores/modelStore.ts @@ -0,0 +1,136 @@ +import { api } from '@/scripts/api' +import { defineStore } from 'pinia' + +/** (Internal helper) finds a value in a metadata object from any of a list of keys. */ +function _findInMetadata(metadata: any, ...keys: string[]): string | null { + for (const key of keys) { + if (key in metadata) { + return metadata[key] + } + for (const k in metadata) { + if (k.endsWith(key)) { + return metadata[k] + } + } + } + return null +} + +/** Defines and holds metadata for a model */ +export class ComfyModelDef { + /** Proper filename of the model */ + name: string = '' + /** Directory containing the model, eg 'checkpoints' */ + directory: string = '' + /** Title / display name of the model, sometimes same as the name but not always */ + title: string = '' + /** Metadata: architecture ID for the model, such as 'stable-diffusion-xl-v1-base' */ + architecture_id: string = '' + /** Metadata: author of the model */ + author: string = '' + /** Metadata: resolution of the model, eg '1024x1024' */ + resolution: string = '' + /** Metadata: description of the model */ + description: string = '' + /** Metadata: usage hint for the model */ + usage_hint: string = '' + /** Metadata: trigger phrase for the model */ + trigger_phrase: string = '' + /** Metadata: tags list for the model */ + tags: string[] = [] + /** Metadata: image for the model */ + image: string = '' + /** Whether the model metadata has been loaded from the server, used for `load()` */ + has_loaded_metadata: boolean = false + + constructor(name: string, directory: string) { + this.name = name + this.title = name + this.directory = directory + } + + /** Loads the model metadata from the server, filling in this object if data is available */ + async load(): Promise { + if (this.has_loaded_metadata) { + return + } + const metadata = await api.viewMetadata(this.directory, this.name) + if (!metadata) { + return + } + this.title = + _findInMetadata( + metadata, + 'modelspec.title', + 'title', + 'display_name', + 'name' + ) || this.name + this.architecture_id = + _findInMetadata(metadata, 'modelspec.architecture', 'architecture') || '' + this.author = _findInMetadata(metadata, 'modelspec.author', 'author') || '' + this.description = + _findInMetadata(metadata, 'modelspec.description', 'description') || '' + this.resolution = + _findInMetadata(metadata, 'modelspec.resolution', 'resolution') || '' + this.usage_hint = + _findInMetadata(metadata, 'modelspec.usage_hint', 'usage_hint') || '' + this.trigger_phrase = + _findInMetadata(metadata, 'modelspec.trigger_phrase', 'trigger_phrase') || + '' + this.image = + _findInMetadata( + metadata, + 'modelspec.thumbnail', + 'thumbnail', + 'image', + 'icon' + ) || '' + const tagsCommaSeparated = + _findInMetadata(metadata, 'modelspec.tags', 'tags') || '' + this.tags = tagsCommaSeparated.split(',').map((tag) => tag.trim()) + this.has_loaded_metadata = true + } +} + +/** Model store for a folder */ +export class ModelStore { + models: Record = {} + + constructor(directory: string, models: string[]) { + for (const model of models) { + this.models[model] = new ComfyModelDef(model, directory) + } + } + + async loadModelMetadata(modelName: string) { + if (this.models[modelName]) { + await this.models[modelName].load() + } + } +} + +/** Model store handler, wraps individual per-folder model stores */ +export const useModelStore = defineStore('modelStore', { + state: () => ({ + modelStoreMap: {} as Record + }), + actions: { + async getModelsInFolderCached(folder: string): Promise { + if (folder in this.modelStoreMap) { + return this.modelStoreMap[folder] + } + // TODO: needs a lock to avoid overlapping calls + const models = await api.getModels(folder) + if (!models) { + return null + } + const store = new ModelStore(folder, models) + this.modelStoreMap[folder] = store + return store + }, + clearCache() { + this.modelStoreMap = {} + } + } +}) diff --git a/tests-ui/tests/store/modelStore.test.ts b/tests-ui/tests/store/modelStore.test.ts new file mode 100644 index 000000000..d488f2ddd --- /dev/null +++ b/tests-ui/tests/store/modelStore.test.ts @@ -0,0 +1,85 @@ +import { setActivePinia, createPinia } from 'pinia' +import { useModelStore } from '@/stores/modelStore' +import { api } from '@/scripts/api' + +// Mock the api +jest.mock('@/scripts/api', () => ({ + api: { + getModels: jest.fn(), + viewMetadata: jest.fn() + } +})) + +function enableMocks() { + ;(api.getModels as jest.Mock).mockResolvedValue([ + 'sdxl.safetensors', + 'sdv15.safetensors', + 'noinfo.safetensors' + ]) + ;(api.viewMetadata as jest.Mock).mockImplementation((_, model) => { + if (model === 'noinfo.safetensors') { + return Promise.resolve({}) + } + return Promise.resolve({ + 'modelspec.title': `Title of ${model}`, + display_name: 'Should not show', + 'modelspec.architecture': 'stable-diffusion-xl-base-v1', + 'modelspec.author': `Author of ${model}`, + 'modelspec.description': `Description of ${model}`, + 'modelspec.resolution': '1024x1024', + trigger_phrase: `Trigger phrase of ${model}`, + usage_hint: `Usage hint of ${model}`, + tags: `tags,for,${model}` + }) + }) +} + +describe('useModelStore', () => { + let store: ReturnType + + beforeEach(() => { + setActivePinia(createPinia()) + store = useModelStore() + }) + + it('should load models', async () => { + enableMocks() + const folderStore = await store.getModelsInFolderCached('checkpoints') + expect(folderStore).not.toBeNull() + expect(Object.keys(folderStore.models)).toHaveLength(3) + }) + + it('should load model metadata', async () => { + enableMocks() + const folderStore = await store.getModelsInFolderCached('checkpoints') + const model = folderStore.models['sdxl.safetensors'] + await model.load() + expect(model.title).toBe('Title of sdxl.safetensors') + expect(model.architecture_id).toBe('stable-diffusion-xl-base-v1') + expect(model.author).toBe('Author of sdxl.safetensors') + expect(model.description).toBe('Description of sdxl.safetensors') + expect(model.resolution).toBe('1024x1024') + expect(model.trigger_phrase).toBe('Trigger phrase of sdxl.safetensors') + expect(model.usage_hint).toBe('Usage hint of sdxl.safetensors') + expect(model.tags).toHaveLength(3) + }) + + it('should handle no metadata', async () => { + enableMocks() + const folderStore = await store.getModelsInFolderCached('checkpoints') + const model = folderStore.models['noinfo.safetensors'] + await model.load() + expect(model.title).toBe('noinfo.safetensors') + expect(model.architecture_id).toBe('') + expect(model.author).toBe('') + expect(model.description).toBe('') + expect(model.resolution).toBe('') + }) + + it('should cache model information', async () => { + enableMocks() + const folderStore1 = await store.getModelsInFolderCached('checkpoints') + const folderStore2 = await store.getModelsInFolderCached('checkpoints') + expect(api.getModels).toHaveBeenCalledTimes(1) + }) +})