Allow dragging model library outputs onto existing nodes (#1004)

* allow multiple compatible node registrations for model type

* allow dragging model library outputs onto existing nodes

* easier registration

* add alt loaders for checkpoint and lora
This commit is contained in:
Alex "mcmonkey" Goodwin
2024-10-02 12:53:19 -07:00
committed by GitHub
parent a737be7e16
commit b3a624a572
2 changed files with 65 additions and 39 deletions

View File

@@ -1,7 +1,6 @@
import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
import { useNodeDefStore } from '@/stores/nodeDefStore'
import { defineStore } from 'pinia'
import { toRaw } from 'vue'
/** Helper class that defines how to construct a node from a model. */
export class ModelNodeProvider {
@@ -20,7 +19,7 @@ export class ModelNodeProvider {
/** Service for mapping model types (by folder name) to nodes. */
export const useModelToNodeStore = defineStore('modelToNode', {
state: () => ({
modelToNodeMap: {} as Record<string, ModelNodeProvider>,
modelToNodeMap: {} as Record<string, ModelNodeProvider[]>,
nodeDefStore: useNodeDefStore(),
haveDefaultsLoaded: false
}),
@@ -31,6 +30,16 @@ export const useModelToNodeStore = defineStore('modelToNode', {
* @returns The node provider for the given model type name.
*/
getNodeProvider(modelType: string): ModelNodeProvider {
this.registerDefaults()
return this.modelToNodeMap[modelType]?.[0]
},
/**
* Get the list of all valid node providers for the given model type name.
* @param modelType The name of the model type to get the node providers for.
* @returns The list of all valid node providers for the given model type name.
*/
getAllNodeProviders(modelType: string): ModelNodeProvider[] {
this.registerDefaults()
return this.modelToNodeMap[modelType]
},
@@ -42,7 +51,21 @@ export const useModelToNodeStore = defineStore('modelToNode', {
*/
registerNodeProvider(modelType: string, nodeProvider: ModelNodeProvider) {
this.registerDefaults()
this.modelToNodeMap[modelType] = nodeProvider
this.modelToNodeMap[modelType] ??= []
this.modelToNodeMap[modelType].push(nodeProvider)
},
/**
* Register a node provider for the given simple names.
* @param modelType The name of the model type to register the node provider for.
* @param nodeClass The node class name to register.
* @param key The key to use for the node input.
*/
quickRegister(modelType: string, nodeClass: string, key: string) {
this.registerNodeProvider(
modelType,
new ModelNodeProvider(this.nodeDefStore.nodeDefsByName[nodeClass], key)
)
},
registerDefaults() {
@@ -53,34 +76,16 @@ export const useModelToNodeStore = defineStore('modelToNode', {
return
}
this.haveDefaultsLoaded = true
this.registerNodeProvider(
this.quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name')
this.quickRegister(
'checkpoints',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['CheckpointLoaderSimple'],
'ckpt_name'
)
)
this.registerNodeProvider(
'loras',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['LoraLoader'],
'lora_name'
)
)
this.registerNodeProvider(
'vae',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['VAELoader'],
'vae_name'
)
)
this.registerNodeProvider(
'controlnet',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['ControlNetLoader'],
'control_net_name'
)
'ImageOnlyCheckpointLoader',
'ckpt_name'
)
this.quickRegister('loras', 'LoraLoader', 'lora_name')
this.quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
this.quickRegister('vae', 'VAELoader', 'vae_name')
this.quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
}
}
})