Allow dragging model library outputs onto existing nodes (#1004)

* allow multiple compatible node registrations for model type

* allow dragging model library outputs onto existing nodes

* easier registration

* add alt loaders for checkpoint and lora
This commit is contained in:
Alex "mcmonkey" Goodwin
2024-10-02 12:53:19 -07:00
committed by GitHub
parent a737be7e16
commit b3a624a572
2 changed files with 65 additions and 39 deletions

View File

@@ -47,7 +47,10 @@ import type { RenderedTreeExplorerNode } from '@/types/treeExplorerTypes'
import { useNodeBookmarkStore } from '@/stores/nodeBookmarkStore' import { useNodeBookmarkStore } from '@/stores/nodeBookmarkStore'
import { useCanvasStore } from '@/stores/graphStore' import { useCanvasStore } from '@/stores/graphStore'
import { ComfyModelDef } from '@/stores/modelStore' import { ComfyModelDef } from '@/stores/modelStore'
import { useModelToNodeStore } from '@/stores/modelToNodeStore' import {
ModelNodeProvider,
useModelToNodeStore
} from '@/stores/modelToNodeStore'
import GraphCanvasMenu from '@/components/graph/GraphCanvasMenu.vue' import GraphCanvasMenu from '@/components/graph/GraphCanvasMenu.vue'
const emit = defineEmits(['ready']) const emit = defineEmits(['ready'])
@@ -160,15 +163,33 @@ onMounted(async () => {
comfyApp.addNodeOnGraph(nodeDef, { pos }) comfyApp.addNodeOnGraph(nodeDef, { pos })
} else if (node.data instanceof ComfyModelDef) { } else if (node.data instanceof ComfyModelDef) {
const model = node.data const model = node.data
const provider = modelToNodeStore.getNodeProvider(model.directory) const pos = comfyApp.clientPosToCanvasPos([loc.clientX, loc.clientY])
if (provider) { const nodeAtPos = comfyApp.graph.getNodeOnPos(pos[0], pos[1])
const pos = comfyApp.clientPosToCanvasPos([ let targetProvider: ModelNodeProvider | null = null
loc.clientX - 20, let targetGraphNode: LGraphNode | null = null
loc.clientY if (nodeAtPos) {
]) const providers = modelToNodeStore.getAllNodeProviders(
const node = comfyApp.addNodeOnGraph(provider.nodeDef, { pos }) model.directory
const widget = node.widgets.find( )
(widget) => widget.name === provider.key for (const provider of providers) {
if (provider.nodeDef.name === nodeAtPos.comfyClass) {
targetGraphNode = nodeAtPos
targetProvider = provider
}
}
}
if (!targetGraphNode) {
const provider = modelToNodeStore.getNodeProvider(model.directory)
if (provider) {
targetGraphNode = comfyApp.addNodeOnGraph(provider.nodeDef, {
pos
})
targetProvider = provider
}
}
if (targetGraphNode) {
const widget = targetGraphNode.widgets.find(
(widget) => widget.name === targetProvider.key
) )
if (widget) { if (widget) {
widget.value = model.file_name widget.value = model.file_name

View File

@@ -1,7 +1,6 @@
import { ComfyNodeDefImpl } from '@/stores/nodeDefStore' import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
import { useNodeDefStore } from '@/stores/nodeDefStore' import { useNodeDefStore } from '@/stores/nodeDefStore'
import { defineStore } from 'pinia' import { defineStore } from 'pinia'
import { toRaw } from 'vue'
/** Helper class that defines how to construct a node from a model. */ /** Helper class that defines how to construct a node from a model. */
export class ModelNodeProvider { export class ModelNodeProvider {
@@ -20,7 +19,7 @@ export class ModelNodeProvider {
/** Service for mapping model types (by folder name) to nodes. */ /** Service for mapping model types (by folder name) to nodes. */
export const useModelToNodeStore = defineStore('modelToNode', { export const useModelToNodeStore = defineStore('modelToNode', {
state: () => ({ state: () => ({
modelToNodeMap: {} as Record<string, ModelNodeProvider>, modelToNodeMap: {} as Record<string, ModelNodeProvider[]>,
nodeDefStore: useNodeDefStore(), nodeDefStore: useNodeDefStore(),
haveDefaultsLoaded: false haveDefaultsLoaded: false
}), }),
@@ -31,6 +30,16 @@ export const useModelToNodeStore = defineStore('modelToNode', {
* @returns The node provider for the given model type name. * @returns The node provider for the given model type name.
*/ */
getNodeProvider(modelType: string): ModelNodeProvider { getNodeProvider(modelType: string): ModelNodeProvider {
this.registerDefaults()
return this.modelToNodeMap[modelType]?.[0]
},
/**
* Get the list of all valid node providers for the given model type name.
* @param modelType The name of the model type to get the node providers for.
* @returns The list of all valid node providers for the given model type name.
*/
getAllNodeProviders(modelType: string): ModelNodeProvider[] {
this.registerDefaults() this.registerDefaults()
return this.modelToNodeMap[modelType] return this.modelToNodeMap[modelType]
}, },
@@ -42,7 +51,21 @@ export const useModelToNodeStore = defineStore('modelToNode', {
*/ */
registerNodeProvider(modelType: string, nodeProvider: ModelNodeProvider) { registerNodeProvider(modelType: string, nodeProvider: ModelNodeProvider) {
this.registerDefaults() this.registerDefaults()
this.modelToNodeMap[modelType] = nodeProvider this.modelToNodeMap[modelType] ??= []
this.modelToNodeMap[modelType].push(nodeProvider)
},
/**
* Register a node provider for the given simple names.
* @param modelType The name of the model type to register the node provider for.
* @param nodeClass The node class name to register.
* @param key The key to use for the node input.
*/
quickRegister(modelType: string, nodeClass: string, key: string) {
this.registerNodeProvider(
modelType,
new ModelNodeProvider(this.nodeDefStore.nodeDefsByName[nodeClass], key)
)
}, },
registerDefaults() { registerDefaults() {
@@ -53,34 +76,16 @@ export const useModelToNodeStore = defineStore('modelToNode', {
return return
} }
this.haveDefaultsLoaded = true this.haveDefaultsLoaded = true
this.registerNodeProvider( this.quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name')
this.quickRegister(
'checkpoints', 'checkpoints',
new ModelNodeProvider( 'ImageOnlyCheckpointLoader',
this.nodeDefStore.nodeDefsByName['CheckpointLoaderSimple'], 'ckpt_name'
'ckpt_name'
)
)
this.registerNodeProvider(
'loras',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['LoraLoader'],
'lora_name'
)
)
this.registerNodeProvider(
'vae',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['VAELoader'],
'vae_name'
)
)
this.registerNodeProvider(
'controlnet',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['ControlNetLoader'],
'control_net_name'
)
) )
this.quickRegister('loras', 'LoraLoader', 'lora_name')
this.quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
this.quickRegister('vae', 'VAELoader', 'vae_name')
this.quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
} }
} }
}) })