Convert pinia stores from options API to composition API (#1330)

* Convert toastStore

* Convert workspaceStateStore

* Convert settingStore

* Convert queueStore

* Convert modelToNodeStore

* Convert modelStore

* Convert dialogStore

* nit

* nit

* nit
This commit is contained in:
Chenlei Hu
2024-10-27 08:47:24 -04:00
committed by GitHub
parent 880437f3c0
commit fa9d944b32
7 changed files with 399 additions and 358 deletions

View File

@@ -2,59 +2,53 @@
// Currently we need to bridge between legacy app code and Vue app with a Pinia store.
import { defineStore } from 'pinia'
import { type Component, markRaw, nextTick } from 'vue'
interface DialogState {
isVisible: boolean
title: string
headerComponent: Component | null
component: Component | null
// Props passing to the component
props: Record<string, any>
// Props passing to the Dialog component
dialogComponentProps: DialogComponentProps
}
import { ref, shallowRef, type Component, markRaw } from 'vue'
interface DialogComponentProps {
maximizable?: boolean
onClose?: () => void
}
export const useDialogStore = defineStore('dialog', {
state: (): DialogState => ({
isVisible: false,
title: '',
headerComponent: null,
component: null,
props: {},
dialogComponentProps: {}
}),
export const useDialogStore = defineStore('dialog', () => {
const isVisible = ref(false)
const title = ref('')
const headerComponent = shallowRef<Component | null>(null)
const component = shallowRef<Component | null>(null)
const props = ref<Record<string, any>>({})
const dialogComponentProps = ref<DialogComponentProps>({})
actions: {
showDialog(options: {
title?: string
headerComponent?: Component
component: Component
props?: Record<string, any>
dialogComponentProps?: DialogComponentProps
}) {
this.isVisible = true
nextTick(() => {
this.title = options.title ?? ''
this.headerComponent = options.headerComponent
? markRaw(options.headerComponent)
: null
this.component = markRaw(options.component)
this.props = options.props || {}
this.dialogComponentProps = options.dialogComponentProps || {}
})
},
function showDialog(options: {
title?: string
headerComponent?: Component
component: Component
props?: Record<string, any>
dialogComponentProps?: DialogComponentProps
}) {
isVisible.value = true
title.value = options.title ?? ''
headerComponent.value = options.headerComponent
? markRaw(options.headerComponent)
: null
component.value = markRaw(options.component)
props.value = options.props || {}
dialogComponentProps.value = options.dialogComponentProps || {}
}
closeDialog() {
if (this.dialogComponentProps.onClose) {
this.dialogComponentProps.onClose()
}
this.isVisible = false
function closeDialog() {
if (dialogComponentProps.value.onClose) {
dialogComponentProps.value.onClose()
}
isVisible.value = false
}
return {
isVisible,
title,
headerComponent,
component,
props,
dialogComponentProps,
showDialog,
closeDialog
}
})

View File

@@ -1,5 +1,6 @@
import { api } from '@/scripts/api'
import { ref } from 'vue'
import { defineStore } from 'pinia'
import { api } from '@/scripts/api'
/** (Internal helper) finds a value in a metadata object from any of a list of keys. */
function _findInMetadata(metadata: any, ...keys: string[]): string | null {
@@ -158,39 +159,51 @@ export class ModelFolder {
const folderBlacklist = ['configs', 'custom_nodes']
/** Model store handler, wraps individual per-folder model stores */
export const useModelStore = defineStore('modelStore', {
state: () => ({
modelStoreMap: {} as Record<string, ModelFolder | null>,
isLoading: {} as Record<string, Promise<ModelFolder | null> | null>,
modelFolders: [] as string[]
}),
actions: {
async getModelsInFolderCached(folder: string): Promise<ModelFolder | null> {
if (folder in this.modelStoreMap) {
return this.modelStoreMap[folder]
}
if (this.isLoading[folder]) {
return this.isLoading[folder]
}
const promise = api.getModels(folder).then((models) => {
if (!models) {
return null
}
const store = new ModelFolder(folder, models)
this.modelStoreMap[folder] = store
this.isLoading[folder] = null
return store
})
this.isLoading[folder] = promise
return promise
},
clearCache() {
this.modelStoreMap = {}
},
async getModelFolders() {
this.modelFolders = (await api.getModelFolders()).filter(
(folder) => !folderBlacklist.includes(folder)
)
export const useModelStore = defineStore('modelStore', () => {
const modelStoreMap = ref<Record<string, ModelFolder | null>>({})
const isLoading = ref<Record<string, Promise<ModelFolder | null> | null>>({})
const modelFolders = ref<string[]>([])
async function getModelsInFolderCached(
folder: string
): Promise<ModelFolder | null> {
if (folder in modelStoreMap.value) {
return modelStoreMap.value[folder]
}
if (isLoading.value[folder]) {
return isLoading.value[folder]
}
const promise = api.getModels(folder).then((models) => {
if (!models) {
return null
}
const store = new ModelFolder(folder, models)
modelStoreMap.value[folder] = store
isLoading.value[folder] = null
return store
})
isLoading.value[folder] = promise
return promise
}
function clearCache() {
Object.keys(modelStoreMap.value).forEach((key) => {
delete modelStoreMap.value[key]
})
}
async function getModelFolders() {
modelFolders.value = (await api.getModelFolders()).filter(
(folder) => !folderBlacklist.includes(folder)
)
}
return {
modelStoreMap,
isLoading,
modelFolders,
getModelsInFolderCached,
clearCache,
getModelFolders
}
})

View File

@@ -1,6 +1,6 @@
import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
import { useNodeDefStore } from '@/stores/nodeDefStore'
import { ref } from 'vue'
import { defineStore } from 'pinia'
import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'
/** Helper class that defines how to construct a node from a model. */
export class ModelNodeProvider {
@@ -17,75 +17,79 @@ export class ModelNodeProvider {
}
/** Service for mapping model types (by folder name) to nodes. */
export const useModelToNodeStore = defineStore('modelToNode', {
state: () => ({
modelToNodeMap: {} as Record<string, ModelNodeProvider[]>,
nodeDefStore: useNodeDefStore(),
haveDefaultsLoaded: false
}),
actions: {
/**
* Get the node provider for the given model type name.
* @param modelType The name of the model type to get the node provider for.
* @returns The node provider for the given model type name.
*/
getNodeProvider(modelType: string): ModelNodeProvider {
this.registerDefaults()
return this.modelToNodeMap[modelType]?.[0]
},
/**
* Get the list of all valid node providers for the given model type name.
* @param modelType The name of the model type to get the node providers for.
* @returns The list of all valid node providers for the given model type name.
*/
getAllNodeProviders(modelType: string): ModelNodeProvider[] {
this.registerDefaults()
return this.modelToNodeMap[modelType] ?? []
},
/**
* Register a node provider for the given model type name.
* @param modelType The name of the model type to register the node provider for.
* @param nodeProvider The node provider to register.
*/
registerNodeProvider(modelType: string, nodeProvider: ModelNodeProvider) {
this.registerDefaults()
this.modelToNodeMap[modelType] ??= []
this.modelToNodeMap[modelType].push(nodeProvider)
},
/**
* Register a node provider for the given simple names.
* @param modelType The name of the model type to register the node provider for.
* @param nodeClass The node class name to register.
* @param key The key to use for the node input.
*/
quickRegister(modelType: string, nodeClass: string, key: string) {
this.registerNodeProvider(
modelType,
new ModelNodeProvider(this.nodeDefStore.nodeDefsByName[nodeClass], key)
)
},
registerDefaults() {
if (this.haveDefaultsLoaded) {
return
}
if (Object.keys(this.nodeDefStore.nodeDefsByName).length === 0) {
return
}
this.haveDefaultsLoaded = true
this.quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name')
this.quickRegister(
'checkpoints',
'ImageOnlyCheckpointLoader',
'ckpt_name'
)
this.quickRegister('loras', 'LoraLoader', 'lora_name')
this.quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
this.quickRegister('vae', 'VAELoader', 'vae_name')
this.quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
export const useModelToNodeStore = defineStore('modelToNode', () => {
const modelToNodeMap = ref<Record<string, ModelNodeProvider[]>>({})
const nodeDefStore = useNodeDefStore()
const haveDefaultsLoaded = ref(false)
/**
* Get the node provider for the given model type name.
* @param modelType The name of the model type to get the node provider for.
* @returns The node provider for the given model type name.
*/
function getNodeProvider(modelType: string): ModelNodeProvider | undefined {
registerDefaults()
return modelToNodeMap.value[modelType]?.[0]
}
/**
* Get the list of all valid node providers for the given model type name.
* @param modelType The name of the model type to get the node providers for.
* @returns The list of all valid node providers for the given model type name.
*/
function getAllNodeProviders(modelType: string): ModelNodeProvider[] {
registerDefaults()
return modelToNodeMap.value[modelType] ?? []
}
/**
* Register a node provider for the given model type name.
* @param modelType The name of the model type to register the node provider for.
* @param nodeProvider The node provider to register.
*/
function registerNodeProvider(
modelType: string,
nodeProvider: ModelNodeProvider
) {
registerDefaults()
if (!modelToNodeMap.value[modelType]) {
modelToNodeMap.value[modelType] = []
}
modelToNodeMap.value[modelType].push(nodeProvider)
}
/**
* Register a node provider for the given simple names.
* @param modelType The name of the model type to register the node provider for.
* @param nodeClass The node class name to register.
* @param key The key to use for the node input.
*/
function quickRegister(modelType: string, nodeClass: string, key: string) {
registerNodeProvider(
modelType,
new ModelNodeProvider(nodeDefStore.nodeDefsByName[nodeClass], key)
)
}
function registerDefaults() {
if (haveDefaultsLoaded.value) {
return
}
if (Object.keys(nodeDefStore.nodeDefsByName).length === 0) {
return
}
haveDefaultsLoaded.value = true
quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name')
quickRegister('checkpoints', 'ImageOnlyCheckpointLoader', 'ckpt_name')
quickRegister('loras', 'LoraLoader', 'lora_name')
quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
quickRegister('vae', 'VAELoader', 'vae_name')
quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
}
return {
modelToNodeMap,
getNodeProvider,
getAllNodeProviders,
registerNodeProvider,
quickRegister,
registerDefaults
}
})

View File

@@ -14,6 +14,7 @@ import type { ComfyWorkflowJSON, NodeId } from '@/types/comfyWorkflow'
import _ from 'lodash'
import { defineStore } from 'pinia'
import { toRaw } from 'vue'
import { ref, computed } from 'vue'
// Task type used in the API.
export type APITaskType = 'queue' | 'history'
@@ -327,98 +328,103 @@ export class TaskItemImpl {
}
}
interface State {
runningTasks: TaskItemImpl[]
pendingTasks: TaskItemImpl[]
historyTasks: TaskItemImpl[]
maxHistoryItems: number
isLoading: boolean
}
export const useQueueStore = defineStore('queue', () => {
const runningTasks = ref<TaskItemImpl[]>([])
const pendingTasks = ref<TaskItemImpl[]>([])
const historyTasks = ref<TaskItemImpl[]>([])
const maxHistoryItems = ref(64)
const isLoading = ref(false)
export const useQueueStore = defineStore('queue', {
state: (): State => ({
runningTasks: [],
pendingTasks: [],
historyTasks: [],
maxHistoryItems: 64,
isLoading: false
}),
getters: {
tasks(state) {
return [
...state.pendingTasks,
...state.runningTasks,
...state.historyTasks
]
},
flatTasks(): TaskItemImpl[] {
return this.tasks.flatMap((task: TaskItemImpl) => task.flatten())
},
lastHistoryQueueIndex(state) {
return state.historyTasks.length ? state.historyTasks[0].queueIndex : -1
},
hasPendingTasks(state) {
return state.pendingTasks.length > 0
}
},
actions: {
// Fetch the queue data from the API
async update() {
this.isLoading = true
try {
const [queue, history] = await Promise.all([
api.getQueue(),
api.getHistory(this.maxHistoryItems)
])
const tasks = computed(() => [
...pendingTasks.value,
...runningTasks.value,
...historyTasks.value
])
const toClassAll = (tasks: TaskItem[]): TaskItemImpl[] =>
tasks
.map(
(task: TaskItem) =>
new TaskItemImpl(
task.taskType,
task.prompt,
task['status'],
task['outputs'] || {}
)
)
// Desc order to show the latest tasks first
.sort((a, b) => b.queueIndex - a.queueIndex)
const flatTasks = computed(() =>
tasks.value.flatMap((task: TaskItemImpl) => task.flatten())
)
this.runningTasks = toClassAll(queue.Running)
this.pendingTasks = toClassAll(queue.Pending)
const lastHistoryQueueIndex = computed(() =>
historyTasks.value.length ? historyTasks.value[0].queueIndex : -1
)
// Process history items
const allIndex = new Set(
history.History.map((item: TaskItem) => item.prompt[0])
)
const newHistoryItems = toClassAll(
history.History.filter(
(item) => item.prompt[0] > this.lastHistoryQueueIndex
const hasPendingTasks = computed(() => pendingTasks.value.length > 0)
const update = async () => {
isLoading.value = true
try {
const [queue, history] = await Promise.all([
api.getQueue(),
api.getHistory(maxHistoryItems.value)
])
const toClassAll = (tasks: TaskItem[]): TaskItemImpl[] =>
tasks
.map(
(task: TaskItem) =>
new TaskItemImpl(
task.taskType,
task.prompt,
task['status'],
task['outputs'] || {}
)
)
)
const existingHistoryItems = this.historyTasks.filter(
(item: TaskItemImpl) => allIndex.has(item.queueIndex)
)
this.historyTasks = [...newHistoryItems, ...existingHistoryItems]
.slice(0, this.maxHistoryItems)
.sort((a, b) => b.queueIndex - a.queueIndex)
} finally {
this.isLoading = false
}
},
async clear(targets: ('queue' | 'history')[] = ['queue', 'history']) {
if (targets.length === 0) {
return
}
await Promise.all(targets.map((type) => api.clearItems(type)))
await this.update()
},
async delete(task: TaskItemImpl) {
await api.deleteItem(task.apiTaskType, task.promptId)
await this.update()
runningTasks.value = toClassAll(queue.Running)
pendingTasks.value = toClassAll(queue.Pending)
const allIndex = new Set(
history.History.map((item: TaskItem) => item.prompt[0])
)
const newHistoryItems = toClassAll(
history.History.filter(
(item) => item.prompt[0] > lastHistoryQueueIndex.value
)
)
const existingHistoryItems = historyTasks.value.filter(
(item: TaskItemImpl) => allIndex.has(item.queueIndex)
)
historyTasks.value = [...newHistoryItems, ...existingHistoryItems]
.slice(0, maxHistoryItems.value)
.sort((a, b) => b.queueIndex - a.queueIndex)
} finally {
isLoading.value = false
}
}
const clear = async (
targets: ('queue' | 'history')[] = ['queue', 'history']
) => {
if (targets.length === 0) {
return
}
await Promise.all(targets.map((type) => api.clearItems(type)))
await update()
}
const deleteTask = async (task: TaskItemImpl) => {
await api.deleteItem(task.apiTaskType, task.promptId)
await update()
}
return {
runningTasks,
pendingTasks,
historyTasks,
maxHistoryItems,
isLoading,
tasks,
flatTasks,
lastHistoryQueueIndex,
hasPendingTasks,
update,
clear,
delete: deleteTask
}
})
export const useQueuePendingTaskCountStore = defineStore(

View File

@@ -7,82 +7,83 @@
* settings directly updates the settingStore.settingValues.
*/
import { ref, computed } from 'vue'
import { defineStore } from 'pinia'
import { app } from '@/scripts/app'
import { ComfySettingsDialog } from '@/scripts/ui/settings'
import type { Settings } from '@/types/apiTypes'
import type { SettingParams } from '@/types/settingTypes'
import { buildTree } from '@/utils/treeUtil'
import { defineStore } from 'pinia'
import type { TreeNode } from 'primevue/treenode'
import type { ComfyExtension } from '@/types/comfy'
import { buildTree } from '@/utils/treeUtil'
import { CORE_SETTINGS } from '@/stores/coreSettings'
import { ComfyExtension } from '@/types/comfy'
export interface SettingTreeNode extends TreeNode {
data?: SettingParams
}
interface State {
settingValues: Record<string, any>
settings: Record<string, SettingParams>
}
export const useSettingStore = defineStore('setting', () => {
const settingValues = ref<Record<string, any>>({})
const settings = ref<Record<string, SettingParams>>({})
export const useSettingStore = defineStore('setting', {
state: (): State => ({
settingValues: {},
settings: {}
}),
getters: {
// Setting tree structure used for the settings dialog display.
settingTree(): SettingTreeNode {
const root = buildTree(
Object.values(this.settings).filter(
(setting: SettingParams) => setting.type !== 'hidden'
),
(setting: SettingParams) => setting.category || setting.id.split('.')
)
const settingTree = computed<SettingTreeNode>(() => {
const root = buildTree(
Object.values(settings.value).filter(
(setting: SettingParams) => setting.type !== 'hidden'
),
(setting: SettingParams) => setting.category || setting.id.split('.')
)
const floatingSettings = (root.children ?? []).filter((node) => node.leaf)
if (floatingSettings.length) {
root.children = (root.children ?? []).filter((node) => !node.leaf)
root.children.push({
key: 'Other',
label: 'Other',
leaf: false,
children: floatingSettings
})
}
return root
}
},
actions: {
addSettings(settings: ComfySettingsDialog) {
for (const id in settings.settingsLookup) {
const value = settings.getSettingValue(id)
this.settingValues[id] = value
}
this.settings = settings.settingsParamLookup
CORE_SETTINGS.forEach((setting: SettingParams) => {
settings.addSetting(setting)
const floatingSettings = (root.children ?? []).filter((node) => node.leaf)
if (floatingSettings.length) {
root.children = (root.children ?? []).filter((node) => !node.leaf)
root.children.push({
key: 'Other',
label: 'Other',
leaf: false,
children: floatingSettings
})
},
loadExtensionSettings(extension: ComfyExtension) {
extension.settings?.forEach((setting: SettingParams) => {
app.ui.settings.addSetting(setting)
})
},
async set<K extends keyof Settings>(key: K, value: Settings[K]) {
this.settingValues[key] = value
await app.ui.settings.setSettingValueAsync(key, value)
},
get<K extends keyof Settings>(key: K): Settings[K] {
return (
this.settingValues[key] ?? app.ui.settings.getSettingDefaultValue(key)
)
}
return root
})
function addSettings(settingsDialog: ComfySettingsDialog) {
for (const id in settingsDialog.settingsLookup) {
const value = settingsDialog.getSettingValue(id)
settingValues.value[id] = value
}
settings.value = settingsDialog.settingsParamLookup
CORE_SETTINGS.forEach((setting: SettingParams) => {
settingsDialog.addSetting(setting)
})
}
function loadExtensionSettings(extension: ComfyExtension) {
extension.settings?.forEach((setting: SettingParams) => {
app.ui.settings.addSetting(setting)
})
}
async function set<K extends keyof Settings>(key: K, value: Settings[K]) {
settingValues.value[key] = value
await app.ui.settings.setSettingValueAsync(key, value)
}
function get<K extends keyof Settings>(key: K): Settings[K] {
return (
settingValues.value[key] ?? app.ui.settings.getSettingDefaultValue(key)
)
}
return {
settingValues,
settings,
settingTree,
addSettings,
loadExtensionSettings,
set,
get
}
})

View File

@@ -2,27 +2,38 @@
// instead of going through the store.
// The store is useful when you need to call it from outside the Vue component context.
import { defineStore } from 'pinia'
import { ref } from 'vue'
import type { ToastMessageOptions } from 'primevue/toast'
export const useToastStore = defineStore('toast', {
state: () => ({
messagesToAdd: [] as ToastMessageOptions[],
messagesToRemove: [] as ToastMessageOptions[],
removeAllRequested: false
}),
export const useToastStore = defineStore('toast', () => {
const messagesToAdd = ref<ToastMessageOptions[]>([])
const messagesToRemove = ref<ToastMessageOptions[]>([])
const removeAllRequested = ref(false)
actions: {
add(message: ToastMessageOptions) {
this.messagesToAdd = [...this.messagesToAdd, message]
},
remove(message: ToastMessageOptions) {
this.messagesToRemove = [...this.messagesToRemove, message]
},
removeAll() {
this.removeAllRequested = true
},
addAlert(message: string) {
this.add({ severity: 'warn', summary: 'Alert', detail: message })
}
function add(message: ToastMessageOptions) {
messagesToAdd.value = [...messagesToAdd.value, message]
}
function remove(message: ToastMessageOptions) {
messagesToRemove.value = [...messagesToRemove.value, message]
}
function removeAll() {
removeAllRequested.value = true
}
function addAlert(message: string) {
add({ severity: 'warn', summary: 'Alert', detail: message })
}
return {
messagesToAdd,
messagesToRemove,
removeAllRequested,
add,
remove,
removeAll,
addAlert
}
})

View File

@@ -1,53 +1,65 @@
import type { SidebarTabExtension, ToastManager } from '@/types/extensionTypes'
import { ref, computed } from 'vue'
import { defineStore } from 'pinia'
import type { SidebarTabExtension, ToastManager } from '@/types/extensionTypes'
import { useToastStore } from './toastStore'
import { useQueueSettingsStore } from './queueStore'
import { useCommandStore } from './commandStore'
import { useSidebarTabStore } from './workspace/sidebarTabStore'
import { useSettingStore } from './settingStore'
interface WorkspaceState {
spinner: boolean
// Whether the shift key is down globally
shiftDown: boolean
}
export const useWorkspaceStore = defineStore('workspace', () => {
const spinner = ref(false)
const shiftDown = ref(false)
export const useWorkspaceStore = defineStore('workspace', {
state: (): WorkspaceState => ({
spinner: false,
shiftDown: false
}),
getters: {
toast(): ToastManager {
return useToastStore()
},
queueSettings() {
return useQueueSettingsStore()
},
command() {
return {
execute: useCommandStore().execute
}
},
sidebarTab() {
return useSidebarTabStore()
},
setting() {
return {
get: useSettingStore().get,
set: useSettingStore().set
}
}
},
actions: {
registerSidebarTab(tab: SidebarTabExtension) {
this.sidebarTab.registerSidebarTab(tab)
},
unregisterSidebarTab(id: string) {
this.sidebarTab.unregisterSidebarTab(id)
},
getSidebarTabs(): SidebarTabExtension[] {
return this.sidebarTab.sidebarTabs
}
const toast = computed<ToastManager>(() => useToastStore())
const queueSettings = computed(() => useQueueSettingsStore())
const command = computed(() => ({
execute: useCommandStore().execute
}))
const sidebarTab = computed(() => useSidebarTabStore())
const setting = computed(() => ({
get: useSettingStore().get,
set: useSettingStore().set
}))
/**
* Registers a sidebar tab.
* @param tab The sidebar tab to register.
* @deprecated Use `sidebarTab.registerSidebarTab` instead.
*/
function registerSidebarTab(tab: SidebarTabExtension) {
sidebarTab.value.registerSidebarTab(tab)
}
/**
* Unregisters a sidebar tab.
* @param id The id of the sidebar tab to unregister.
* @deprecated Use `sidebarTab.unregisterSidebarTab` instead.
*/
function unregisterSidebarTab(id: string) {
sidebarTab.value.unregisterSidebarTab(id)
}
/**
* Gets all registered sidebar tabs.
* @returns All registered sidebar tabs.
* @deprecated Use `sidebarTab.sidebarTabs` instead.
*/
function getSidebarTabs(): SidebarTabExtension[] {
return sidebarTab.value.sidebarTabs
}
return {
spinner,
shiftDown,
toast,
queueSettings,
command,
sidebarTab,
setting,
registerSidebarTab,
unregisterSidebarTab,
getSidebarTabs
}
})