Allow dragging model library outputs onto existing nodes (#1004)

* allow multiple compatible node registrations for model type

* allow dragging model library outputs onto existing nodes

* easier registration

* add alt loaders for checkpoint and lora
This commit is contained in:
Alex "mcmonkey" Goodwin
2024-10-02 12:53:19 -07:00
committed by GitHub
parent a737be7e16
commit b3a624a572
2 changed files with 65 additions and 39 deletions

View File

@@ -47,7 +47,10 @@ import type { RenderedTreeExplorerNode } from '@/types/treeExplorerTypes'
import { useNodeBookmarkStore } from '@/stores/nodeBookmarkStore'
import { useCanvasStore } from '@/stores/graphStore'
import { ComfyModelDef } from '@/stores/modelStore'
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
import {
ModelNodeProvider,
useModelToNodeStore
} from '@/stores/modelToNodeStore'
import GraphCanvasMenu from '@/components/graph/GraphCanvasMenu.vue'
const emit = defineEmits(['ready'])
@@ -160,15 +163,33 @@ onMounted(async () => {
comfyApp.addNodeOnGraph(nodeDef, { pos })
} else if (node.data instanceof ComfyModelDef) {
const model = node.data
const provider = modelToNodeStore.getNodeProvider(model.directory)
if (provider) {
const pos = comfyApp.clientPosToCanvasPos([
loc.clientX - 20,
loc.clientY
])
const node = comfyApp.addNodeOnGraph(provider.nodeDef, { pos })
const widget = node.widgets.find(
(widget) => widget.name === provider.key
const pos = comfyApp.clientPosToCanvasPos([loc.clientX, loc.clientY])
const nodeAtPos = comfyApp.graph.getNodeOnPos(pos[0], pos[1])
let targetProvider: ModelNodeProvider | null = null
let targetGraphNode: LGraphNode | null = null
if (nodeAtPos) {
const providers = modelToNodeStore.getAllNodeProviders(
model.directory
)
for (const provider of providers) {
if (provider.nodeDef.name === nodeAtPos.comfyClass) {
targetGraphNode = nodeAtPos
targetProvider = provider
}
}
}
if (!targetGraphNode) {
const provider = modelToNodeStore.getNodeProvider(model.directory)
if (provider) {
targetGraphNode = comfyApp.addNodeOnGraph(provider.nodeDef, {
pos
})
targetProvider = provider
}
}
if (targetGraphNode) {
const widget = targetGraphNode.widgets.find(
(widget) => widget.name === targetProvider.key
)
if (widget) {
widget.value = model.file_name

View File

@@ -1,7 +1,6 @@
import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
import { useNodeDefStore } from '@/stores/nodeDefStore'
import { defineStore } from 'pinia'
import { toRaw } from 'vue'
/** Helper class that defines how to construct a node from a model. */
export class ModelNodeProvider {
@@ -20,7 +19,7 @@ export class ModelNodeProvider {
/** Service for mapping model types (by folder name) to nodes. */
export const useModelToNodeStore = defineStore('modelToNode', {
state: () => ({
modelToNodeMap: {} as Record<string, ModelNodeProvider>,
modelToNodeMap: {} as Record<string, ModelNodeProvider[]>,
nodeDefStore: useNodeDefStore(),
haveDefaultsLoaded: false
}),
@@ -31,6 +30,16 @@ export const useModelToNodeStore = defineStore('modelToNode', {
* @returns The node provider for the given model type name.
*/
getNodeProvider(modelType: string): ModelNodeProvider {
this.registerDefaults()
return this.modelToNodeMap[modelType]?.[0]
},
/**
* Get the list of all valid node providers for the given model type name.
* @param modelType The name of the model type to get the node providers for.
* @returns The list of all valid node providers for the given model type name.
*/
getAllNodeProviders(modelType: string): ModelNodeProvider[] {
this.registerDefaults()
return this.modelToNodeMap[modelType]
},
@@ -42,7 +51,21 @@ export const useModelToNodeStore = defineStore('modelToNode', {
*/
registerNodeProvider(modelType: string, nodeProvider: ModelNodeProvider) {
this.registerDefaults()
this.modelToNodeMap[modelType] = nodeProvider
this.modelToNodeMap[modelType] ??= []
this.modelToNodeMap[modelType].push(nodeProvider)
},
/**
* Register a node provider for the given simple names.
* @param modelType The name of the model type to register the node provider for.
* @param nodeClass The node class name to register.
* @param key The key to use for the node input.
*/
quickRegister(modelType: string, nodeClass: string, key: string) {
this.registerNodeProvider(
modelType,
new ModelNodeProvider(this.nodeDefStore.nodeDefsByName[nodeClass], key)
)
},
registerDefaults() {
@@ -53,34 +76,16 @@ export const useModelToNodeStore = defineStore('modelToNode', {
return
}
this.haveDefaultsLoaded = true
this.registerNodeProvider(
this.quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name')
this.quickRegister(
'checkpoints',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['CheckpointLoaderSimple'],
'ckpt_name'
)
)
this.registerNodeProvider(
'loras',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['LoraLoader'],
'lora_name'
)
)
this.registerNodeProvider(
'vae',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['VAELoader'],
'vae_name'
)
)
this.registerNodeProvider(
'controlnet',
new ModelNodeProvider(
this.nodeDefStore.nodeDefsByName['ControlNetLoader'],
'control_net_name'
)
'ImageOnlyCheckpointLoader',
'ckpt_name'
)
this.quickRegister('loras', 'LoraLoader', 'lora_name')
this.quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
this.quickRegister('vae', 'VAELoader', 'vae_name')
this.quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
}
}
})