any
-}
-
-export function useNodeEventHandlers(nodeManager: Ref
) {
+function useNodeEventHandlersIndividual() {
const canvasStore = useCanvasStore()
+ const { nodeManager } = useVueNodeLifecycle()
const { bringNodeToFront } = useNodeZIndex()
const { shouldHandleNodePointerEvents } = useCanvasInteractions()
@@ -40,7 +38,7 @@ export function useNodeEventHandlers(nodeManager: Ref) {
const node = nodeManager.value.getNode(nodeData.id)
if (!node) return
- const isMultiSelect = event.ctrlKey || event.metaKey
+ const isMultiSelect = event.ctrlKey || event.metaKey || event.shiftKey
if (isMultiSelect) {
// Ctrl/Cmd+click -> toggle selection
@@ -84,6 +82,7 @@ export function useNodeEventHandlers(nodeManager: Ref) {
const currentCollapsed = node.flags?.collapsed ?? false
if (currentCollapsed !== collapsed) {
node.collapse()
+ nodeManager.value.scheduleUpdate(nodeId, 'critical')
}
}
@@ -237,3 +236,7 @@ export function useNodeEventHandlers(nodeManager: Ref) {
deselectNodes
}
}
+
+export const useNodeEventHandlers = createSharedComposable(
+ useNodeEventHandlersIndividual
+)
diff --git a/src/renderer/extensions/vueNodes/composables/useNodePointerInteractions.ts b/src/renderer/extensions/vueNodes/composables/useNodePointerInteractions.ts
new file mode 100644
index 000000000..f5ba08374
--- /dev/null
+++ b/src/renderer/extensions/vueNodes/composables/useNodePointerInteractions.ts
@@ -0,0 +1,93 @@
+import { type MaybeRefOrGetter, computed, ref, toValue } from 'vue'
+
+import type { VueNodeData } from '@/composables/graph/useGraphNodeManager'
+import { useCanvasInteractions } from '@/renderer/core/canvas/useCanvasInteractions'
+import { layoutStore } from '@/renderer/core/layout/store/layoutStore'
+import { useNodeLayout } from '@/renderer/extensions/vueNodes/layout/useNodeLayout'
+
+// Treat tiny pointer jitter as a click, not a drag
+const DRAG_THRESHOLD_PX = 4
+
+export function useNodePointerInteractions(
+ nodeDataMaybe: MaybeRefOrGetter,
+ onPointerUp: (
+ event: PointerEvent,
+ nodeData: VueNodeData,
+ wasDragging: boolean
+ ) => void
+) {
+ const nodeData = toValue(nodeDataMaybe)
+
+ const { startDrag, endDrag, handleDrag } = useNodeLayout(nodeData.id)
+ // Use canvas interactions for proper wheel event handling and pointer event capture control
+ const { forwardEventToCanvas, shouldHandleNodePointerEvents } =
+ useCanvasInteractions()
+
+ // Drag state for styling
+ const isDragging = ref(false)
+ const dragStyle = computed(() => ({
+ cursor: isDragging.value ? 'grabbing' : 'grab'
+ }))
+ const lastX = ref(0)
+ const lastY = ref(0)
+
+ const handlePointerDown = (event: PointerEvent) => {
+ if (!nodeData) {
+ console.warn(
+ 'LGraphNode: nodeData is null/undefined in handlePointerDown'
+ )
+ return
+ }
+
+ // Don't handle pointer events when canvas is in panning mode - forward to canvas instead
+ if (!shouldHandleNodePointerEvents.value) {
+ forwardEventToCanvas(event)
+ return
+ }
+
+ // Start drag using layout system
+ isDragging.value = true
+
+ // Set Vue node dragging state for selection toolbox
+ layoutStore.isDraggingVueNodes.value = true
+
+ startDrag(event)
+ lastY.value = event.clientY
+ lastX.value = event.clientX
+ }
+
+ const handlePointerMove = (event: PointerEvent) => {
+ if (isDragging.value) {
+ void handleDrag(event)
+ }
+ }
+
+ const handlePointerUp = (event: PointerEvent) => {
+ if (isDragging.value) {
+ isDragging.value = false
+ void endDrag(event)
+
+ // Clear Vue node dragging state for selection toolbox
+ layoutStore.isDraggingVueNodes.value = false
+ }
+
+ // Don't emit node-click when canvas is in panning mode - forward to canvas instead
+ if (!shouldHandleNodePointerEvents.value) {
+ forwardEventToCanvas(event)
+ return
+ }
+
+ // Emit node-click for selection handling in GraphCanvas
+ const dx = event.clientX - lastX.value
+ const dy = event.clientY - lastY.value
+ const wasDragging = Math.hypot(dx, dy) > DRAG_THRESHOLD_PX
+ onPointerUp(event, nodeData, wasDragging)
+ }
+ return {
+ isDragging,
+ dragStyle,
+ handlePointerMove,
+ handlePointerDown,
+ handlePointerUp
+ }
+}
diff --git a/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts b/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts
index 0dd9922ef..034047471 100644
--- a/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts
+++ b/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts
@@ -93,10 +93,10 @@ export function useNodeTooltips(
pt: {
text: {
class:
- 'bg-charcoal-100 border border-slate-300 rounded-md px-4 py-2 text-white text-sm font-normal leading-tight max-w-75 shadow-none'
+ 'bg-charcoal-800 border border-slate-300 rounded-md px-4 py-2 text-white text-sm font-normal leading-tight max-w-75 shadow-none'
},
arrow: {
- class: 'before:border-charcoal-100'
+ class: 'before:border-slate-300'
}
}
}
diff --git a/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts b/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts
index 4ce9f8e62..e8c38164d 100644
--- a/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts
+++ b/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts
@@ -8,7 +8,13 @@
* Supports different element types (nodes, slots, widgets, etc.) with
* customizable data attributes and update handlers.
*/
-import { getCurrentInstance, onMounted, onUnmounted } from 'vue'
+import {
+ type MaybeRefOrGetter,
+ getCurrentInstance,
+ onMounted,
+ onUnmounted,
+ toValue
+} from 'vue'
import { useSharedCanvasPositionConversion } from '@/composables/element/useCanvasPositionConversion'
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
@@ -154,9 +160,10 @@ const resizeObserver = new ResizeObserver((entries) => {
* ```
*/
export function useVueElementTracking(
- appIdentifier: string,
+ appIdentifierMaybe: MaybeRefOrGetter,
trackingType: string
) {
+ const appIdentifier = toValue(appIdentifierMaybe)
onMounted(() => {
const element = getCurrentInstance()?.proxy?.$el
if (!(element instanceof HTMLElement) || !appIdentifier) return
diff --git a/src/renderer/extensions/vueNodes/execution/useExecutionStateProvider.ts b/src/renderer/extensions/vueNodes/execution/useExecutionStateProvider.ts
deleted file mode 100644
index aae08298a..000000000
--- a/src/renderer/extensions/vueNodes/execution/useExecutionStateProvider.ts
+++ /dev/null
@@ -1,36 +0,0 @@
-import { storeToRefs } from 'pinia'
-import { computed, provide } from 'vue'
-
-import {
- ExecutingNodeIdsKey,
- NodeProgressStatesKey
-} from '@/renderer/core/canvas/injectionKeys'
-import { useExecutionStore } from '@/stores/executionStore'
-
-/**
- * Composable for providing execution state to Vue node children
- *
- * This composable sets up the execution state providers that can be injected
- * by child Vue nodes using useNodeExecutionState.
- *
- * Should be used in the parent component that manages Vue nodes (e.g., GraphCanvas).
- */
-export const useExecutionStateProvider = () => {
- const executionStore = useExecutionStore()
- const { executingNodeIds: storeExecutingNodeIds, nodeProgressStates } =
- storeToRefs(executionStore)
-
- // Convert execution store data to the format expected by Vue nodes
- const executingNodeIds = computed(
- () => new Set(storeExecutingNodeIds.value.map(String))
- )
-
- // Provide the execution state to all child Vue nodes
- provide(ExecutingNodeIdsKey, executingNodeIds)
- provide(NodeProgressStatesKey, nodeProgressStates)
-
- return {
- executingNodeIds,
- nodeProgressStates
- }
-}
diff --git a/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts b/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts
index 8f03e29e1..aa4867db9 100644
--- a/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts
+++ b/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts
@@ -1,10 +1,7 @@
-import { computed, inject, ref } from 'vue'
+import { storeToRefs } from 'pinia'
+import { type MaybeRefOrGetter, computed, toValue } from 'vue'
-import {
- ExecutingNodeIdsKey,
- NodeProgressStatesKey
-} from '@/renderer/core/canvas/injectionKeys'
-import type { NodeProgressState } from '@/schemas/apiSchema'
+import { useExecutionStore } from '@/stores/executionStore'
/**
* Composable for managing execution state of Vue-based nodes
@@ -12,18 +9,18 @@ import type { NodeProgressState } from '@/schemas/apiSchema'
* Provides reactive access to execution state and progress for a specific node
* by injecting execution data from the parent GraphCanvas provider.
*
- * @param nodeId - The ID of the node to track execution state for
+ * @param nodeIdMaybe - The ID of the node to track execution state for
* @returns Object containing reactive execution state and progress
*/
-export const useNodeExecutionState = (nodeId: string) => {
- const executingNodeIds = inject(ExecutingNodeIdsKey, ref(new Set()))
- const nodeProgressStates = inject(
- NodeProgressStatesKey,
- ref>({})
- )
+export const useNodeExecutionState = (
+ nodeIdMaybe: MaybeRefOrGetter
+) => {
+ const nodeId = toValue(nodeIdMaybe)
+ const { uniqueExecutingNodeIdStrings, nodeProgressStates } =
+ storeToRefs(useExecutionStore())
const executing = computed(() => {
- return executingNodeIds.value.has(nodeId)
+ return uniqueExecutingNodeIdStrings.value.has(nodeId)
})
const progress = computed(() => {
diff --git a/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts b/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts
index 18a085641..89718eb8d 100644
--- a/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts
+++ b/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts
@@ -1,12 +1,7 @@
-/**
- * Composable for individual Vue node components
- *
- * Uses customRef for shared write access with Canvas renderer.
- * Provides dragging functionality and reactive layout state.
- */
-import { computed, inject } from 'vue'
+import { storeToRefs } from 'pinia'
+import { type MaybeRefOrGetter, computed, inject, toValue } from 'vue'
-import { SelectedNodeIdsKey } from '@/renderer/core/canvas/injectionKeys'
+import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
import { TransformStateKey } from '@/renderer/core/layout/injectionKeys'
import { useLayoutMutations } from '@/renderer/core/layout/operations/layoutMutations'
import { layoutStore } from '@/renderer/core/layout/store/layoutStore'
@@ -16,15 +11,16 @@ import { LayoutSource, type Point } from '@/renderer/core/layout/types'
* Composable for individual Vue node components
* Uses customRef for shared write access with Canvas renderer
*/
-export function useNodeLayout(nodeId: string) {
- const store = layoutStore
+export function useNodeLayout(nodeIdMaybe: MaybeRefOrGetter) {
+ const nodeId = toValue(nodeIdMaybe)
const mutations = useLayoutMutations()
+ const { selectedNodeIds } = storeToRefs(useCanvasStore())
// Get transform utilities from TransformPane if available
const transformState = inject(TransformStateKey)
// Get the customRef for this node (shared write access)
- const layoutRef = store.getNodeLayoutRef(nodeId)
+ const layoutRef = layoutStore.getNodeLayoutRef(nodeId)
// Computed properties for easy access
const position = computed(() => {
@@ -53,8 +49,6 @@ export function useNodeLayout(nodeId: string) {
let dragStartMouse: Point | null = null
let otherSelectedNodesStartPositions: Map | null = null
- const selectedNodeIds = inject(SelectedNodeIdsKey, null)
-
/**
* Start dragging the node
*/
diff --git a/src/renderer/extensions/vueNodes/lod/useLOD.ts b/src/renderer/extensions/vueNodes/lod/useLOD.ts
index 584e21f9a..87c1bb865 100644
--- a/src/renderer/extensions/vueNodes/lod/useLOD.ts
+++ b/src/renderer/extensions/vueNodes/lod/useLOD.ts
@@ -27,7 +27,7 @@
*
* ```
*/
-import { type Ref, computed, readonly } from 'vue'
+import { type MaybeRefOrGetter, computed, readonly, toRef } from 'vue'
export enum LODLevel {
MINIMAL = 'minimal', // zoom <= 0.4
@@ -78,7 +78,8 @@ const LOD_CONFIGS: Record = {
* @param zoomRef - Reactive reference to current zoom level (camera.z)
* @returns LOD state and configuration
*/
-export function useLOD(zoomRef: Ref) {
+export function useLOD(zoomRefMaybe: MaybeRefOrGetter) {
+ const zoomRef = toRef(zoomRefMaybe)
// Continuous LOD score (0-1) for smooth transitions
const lodScore = computed(() => {
const zoom = zoomRef.value
diff --git a/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts b/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts
index 427cf20c0..8fc82147a 100644
--- a/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts
+++ b/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts
@@ -1,16 +1,17 @@
import { storeToRefs } from 'pinia'
-import { type Ref, computed } from 'vue'
+import { type MaybeRefOrGetter, type Ref, computed, toValue } from 'vue'
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
export const useNodePreviewState = (
- nodeId: string,
+ nodeIdMaybe: MaybeRefOrGetter,
options?: {
isMinimalLOD?: Ref
isCollapsed?: Ref
}
) => {
+ const nodeId = toValue(nodeIdMaybe)
const workflowStore = useWorkflowStore()
const { nodePreviewImages } = storeToRefs(useNodeOutputStore())
diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts
index 2387fc59c..59705458c 100644
--- a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts
+++ b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts
@@ -3,10 +3,9 @@ import { ref } from 'vue'
import MultiSelectWidget from '@/components/graph/widgets/MultiSelectWidget.vue'
import { t } from '@/i18n'
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
-import type {
- IBaseWidget,
- IComboWidget
-} from '@/lib/litegraph/src/types/widgets'
+import { isAssetWidget, isComboWidget } from '@/lib/litegraph/src/litegraph'
+import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
+import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
import { assetService } from '@/platform/assets/services/assetService'
import { useSettingStore } from '@/platform/settings/settingStore'
import { transformInputSpecV2ToV1 } from '@/schemas/nodeDef/migration'
@@ -73,11 +72,29 @@ const addComboWidget = (
const currentValue = getDefaultValue(inputSpec)
const displayLabel = currentValue ?? t('widgets.selectModel')
- const widget = node.addWidget('asset', inputSpec.name, displayLabel, () => {
- console.log(
- `Asset Browser would open here for:\nNode: ${node.type}\nWidget: ${inputSpec.name}\nCurrent Value:${currentValue}`
- )
- })
+ const assetBrowserDialog = useAssetBrowserDialog()
+
+ const widget = node.addWidget(
+ 'asset',
+ inputSpec.name,
+ displayLabel,
+ async () => {
+ if (!isAssetWidget(widget)) {
+ throw new Error(`Expected asset widget but received ${widget.type}`)
+ }
+ await assetBrowserDialog.show({
+ nodeType: node.comfyClass || '',
+ inputName: inputSpec.name,
+ currentValue: widget.value,
+ onAssetSelected: (filename: string) => {
+ const oldValue = widget.value
+ widget.value = filename
+ // Using onWidgetChanged prevents a callback race where asset selection could reopen the dialog
+ node.onWidgetChanged?.(widget.name, filename, oldValue, widget)
+ }
+ })
+ }
+ )
return widget
}
@@ -96,11 +113,14 @@ const addComboWidget = (
)
if (inputSpec.remote) {
+ if (!isComboWidget(widget)) {
+ throw new Error(`Expected combo widget but received ${widget.type}`)
+ }
const remoteWidget = useRemoteWidget({
remoteConfig: inputSpec.remote,
defaultValue,
node,
- widget: widget as IComboWidget
+ widget
})
if (inputSpec.remote.refresh_button) remoteWidget.addRefreshButton()
@@ -116,16 +136,19 @@ const addComboWidget = (
}
if (inputSpec.control_after_generate) {
+ if (!isComboWidget(widget)) {
+ throw new Error(`Expected combo widget but received ${widget.type}`)
+ }
widget.linkedWidgets = addValueControlWidgets(
node,
- widget as IComboWidget,
+ widget,
undefined,
undefined,
transformInputSpecV2ToV1(inputSpec)
)
}
- return widget as IBaseWidget
+ return widget
}
export const useComboWidget = () => {
diff --git a/src/scripts/app.ts b/src/scripts/app.ts
index 43722f9b4..cd085eead 100644
--- a/src/scripts/app.ts
+++ b/src/scripts/app.ts
@@ -596,7 +596,10 @@ export class ComfyApp {
const keybindingStore = useKeybindingStore()
const keybinding = keybindingStore.getKeybinding(keyCombo)
- if (keybinding && keybinding.targetElementId === 'graph-canvas') {
+ if (
+ keybinding &&
+ keybinding.targetElementId === 'graph-canvas-container'
+ ) {
useCommandStore().execute(keybinding.commandId)
this.graph.change()
diff --git a/src/stores/executionStore.ts b/src/stores/executionStore.ts
index cfc7a2dd4..8791ab4e1 100644
--- a/src/stores/executionStore.ts
+++ b/src/stores/executionStore.ts
@@ -43,6 +43,57 @@ interface QueuedPrompt {
workflow?: ComfyWorkflow
}
+const subgraphNodeIdToSubgraph = (id: string, graph: LGraph | Subgraph) => {
+ const node = graph.getNodeById(id)
+ if (node?.isSubgraphNode()) return node.subgraph
+}
+
+/**
+ * Recursively get the subgraph objects for the given subgraph instance IDs
+ * @param currentGraph The current graph
+ * @param subgraphNodeIds The instance IDs
+ * @param subgraphs The subgraphs
+ * @returns The subgraphs that correspond to each of the instance IDs.
+ */
+function getSubgraphsFromInstanceIds(
+ currentGraph: LGraph | Subgraph,
+ subgraphNodeIds: string[],
+ subgraphs: Subgraph[] = []
+): Subgraph[] {
+ // Last segment is the node portion; nothing to do.
+ if (subgraphNodeIds.length === 1) return subgraphs
+
+ const currentPart = subgraphNodeIds.shift()
+ if (currentPart === undefined) return subgraphs
+
+ const subgraph = subgraphNodeIdToSubgraph(currentPart, currentGraph)
+ if (!subgraph) throw new Error(`Subgraph not found: ${currentPart}`)
+
+ subgraphs.push(subgraph)
+ return getSubgraphsFromInstanceIds(subgraph, subgraphNodeIds, subgraphs)
+}
+
+/**
+ * Convert execution context node IDs to NodeLocatorIds
+ * @param nodeId The node ID from execution context (could be execution ID)
+ * @returns The NodeLocatorId
+ */
+function executionIdToNodeLocatorId(nodeId: string | number): NodeLocatorId {
+ const nodeIdStr = String(nodeId)
+
+ if (!nodeIdStr.includes(':')) {
+ // It's a top-level node ID
+ return nodeIdStr
+ }
+
+ // It's an execution node ID
+ const parts = nodeIdStr.split(':')
+ const localNodeId = parts[parts.length - 1]
+ const subgraphs = getSubgraphsFromInstanceIds(app.graph, parts)
+ const nodeLocatorId = createNodeLocatorId(subgraphs.at(-1)!.id, localNodeId)
+ return nodeLocatorId
+}
+
export const useExecutionStore = defineStore('execution', () => {
const workflowStore = useWorkflowStore()
const canvasStore = useCanvasStore()
@@ -55,29 +106,6 @@ export const useExecutionStore = defineStore('execution', () => {
// This is the progress of all nodes in the currently executing workflow
const nodeProgressStates = ref>({})
- /**
- * Convert execution context node IDs to NodeLocatorIds
- * @param nodeId The node ID from execution context (could be execution ID)
- * @returns The NodeLocatorId
- */
- const executionIdToNodeLocatorId = (
- nodeId: string | number
- ): NodeLocatorId => {
- const nodeIdStr = String(nodeId)
-
- if (!nodeIdStr.includes(':')) {
- // It's a top-level node ID
- return nodeIdStr
- }
-
- // It's an execution node ID
- const parts = nodeIdStr.split(':')
- const localNodeId = parts[parts.length - 1]
- const subgraphs = getSubgraphsFromInstanceIds(app.graph, parts)
- const nodeLocatorId = createNodeLocatorId(subgraphs.at(-1)!.id, localNodeId)
- return nodeLocatorId
- }
-
const mergeExecutionProgressStates = (
currentState: NodeProgressState | undefined,
newState: NodeProgressState
@@ -139,9 +167,13 @@ export const useExecutionStore = defineStore('execution', () => {
// @deprecated For backward compatibility - stores the primary executing node ID
const executingNodeId = computed(() => {
- return executingNodeIds.value.length > 0 ? executingNodeIds.value[0] : null
+ return executingNodeIds.value[0] ?? null
})
+ const uniqueExecutingNodeIdStrings = computed(
+ () => new Set(executingNodeIds.value.map(String))
+ )
+
// For backward compatibility - returns the primary executing node
const executingNode = computed(() => {
if (!executingNodeId.value) return null
@@ -159,36 +191,6 @@ export const useExecutionStore = defineStore('execution', () => {
)
})
- const subgraphNodeIdToSubgraph = (id: string, graph: LGraph | Subgraph) => {
- const node = graph.getNodeById(id)
- if (node?.isSubgraphNode()) return node.subgraph
- }
-
- /**
- * Recursively get the subgraph objects for the given subgraph instance IDs
- * @param currentGraph The current graph
- * @param subgraphNodeIds The instance IDs
- * @param subgraphs The subgraphs
- * @returns The subgraphs that correspond to each of the instance IDs.
- */
- const getSubgraphsFromInstanceIds = (
- currentGraph: LGraph | Subgraph,
- subgraphNodeIds: string[],
- subgraphs: Subgraph[] = []
- ): Subgraph[] => {
- // Last segment is the node portion; nothing to do.
- if (subgraphNodeIds.length === 1) return subgraphs
-
- const currentPart = subgraphNodeIds.shift()
- if (currentPart === undefined) return subgraphs
-
- const subgraph = subgraphNodeIdToSubgraph(currentPart, currentGraph)
- if (!subgraph) throw new Error(`Subgraph not found: ${currentPart}`)
-
- subgraphs.push(subgraph)
- return getSubgraphsFromInstanceIds(subgraph, subgraphNodeIds, subgraphs)
- }
-
// This is the progress of the currently executing node (for backward compatibility)
const _executingNodeProgress = ref(null)
const executingNodeProgress = computed(() =>
@@ -423,66 +425,25 @@ export const useExecutionStore = defineStore('execution', () => {
return {
isIdle,
clientId,
- /**
- * The id of the prompt that is currently being executed
- */
activePromptId,
- /**
- * The queued prompts
- */
queuedPrompts,
- /**
- * The node errors from the previous execution.
- */
lastNodeErrors,
- /**
- * The error from the previous execution.
- */
lastExecutionError,
- /**
- * Local node ID for the most recent execution error.
- */
lastExecutionErrorNodeId,
- /**
- * The id of the node that is currently being executed (backward compatibility)
- */
executingNodeId,
- /**
- * The list of all nodes that are currently executing
- */
executingNodeIds,
- /**
- * The prompt that is currently being executed
- */
activePrompt,
- /**
- * The total number of nodes to execute
- */
totalNodesToExecute,
- /**
- * The number of nodes that have been executed
- */
nodesExecuted,
- /**
- * The progress of the execution
- */
executionProgress,
- /**
- * The node that is currently being executed (backward compatibility)
- */
executingNode,
- /**
- * The progress of the executing node (backward compatibility)
- */
executingNodeProgress,
- /**
- * All node progress states from progress_state events
- */
nodeProgressStates,
nodeLocationProgressStates,
bindExecutionEvents,
unbindExecutionEvents,
storePrompt,
+ uniqueExecutingNodeIdStrings,
// Raw executing progress data for backward compatibility in ComfyApp.
_executingNodeProgress,
// NodeLocatorId conversion helpers
diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts
index f6c15e91a..4f7925294 100644
--- a/src/stores/modelToNodeStore.ts
+++ b/src/stores/modelToNodeStore.ts
@@ -33,12 +33,43 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
)
})
+ /** Internal computed for efficient reverse lookup: nodeType -> category */
+ const nodeTypeToCategory = computed(() => {
+ const lookup: Record = {}
+ for (const [category, providers] of Object.entries(modelToNodeMap.value)) {
+ for (const provider of providers) {
+ // Only store the first category for each node type (matches current assetService behavior)
+ if (!lookup[provider.nodeDef.name]) {
+ lookup[provider.nodeDef.name] = category
+ }
+ }
+ }
+ return lookup
+ })
+
/** Get set of all registered node types for efficient lookup */
function getRegisteredNodeTypes(): Set {
registerDefaults()
return registeredNodeTypes.value
}
+ /**
+ * Get the category for a given node type.
+ * Performs efficient O(1) lookup using cached reverse map.
+ * @param nodeType The node type name to find the category for
+ * @returns The category name, or undefined if not found
+ */
+ function getCategoryForNodeType(nodeType: string): string | undefined {
+ registerDefaults()
+
+ // Handle invalid input gracefully
+ if (!nodeType || typeof nodeType !== 'string') {
+ return undefined
+ }
+
+ return nodeTypeToCategory.value[nodeType]
+ }
+
/**
* Get the node provider for the given model type name.
* @param modelType The name of the model type to get the node provider for.
@@ -109,6 +140,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
return {
modelToNodeMap,
getRegisteredNodeTypes,
+ getCategoryForNodeType,
getNodeProvider,
getAllNodeProviders,
registerNodeProvider,
diff --git a/src/types/litegraph-augmentation.d.ts b/src/types/litegraph-augmentation.d.ts
index d404ff88f..0f5dee17d 100644
--- a/src/types/litegraph-augmentation.d.ts
+++ b/src/types/litegraph-augmentation.d.ts
@@ -82,7 +82,7 @@ declare module '@/lib/litegraph/src/litegraph' {
}
// Add interface augmentations into the class itself
- // eslint-disable-next-line @typescript-eslint/no-empty-object-type
+
interface BaseWidget extends IBaseWidget {}
interface LGraphNode {
diff --git a/src/views/GraphView.vue b/src/views/GraphView.vue
index a23d4fecf..bbfca6ea0 100644
--- a/src/views/GraphView.vue
+++ b/src/views/GraphView.vue
@@ -33,6 +33,7 @@ import {
} from 'vue'
import { useI18n } from 'vue-i18n'
+import { runWhenGlobalIdle } from '@/base/common/async'
import MenuHamburger from '@/components/MenuHamburger.vue'
import UnloadWindowConfirmDialog from '@/components/dialog/UnloadWindowConfirmDialog.vue'
import GraphCanvas from '@/components/graph/GraphCanvas.vue'
@@ -253,33 +254,30 @@ void nextTick(() => {
})
const onGraphReady = () => {
- requestIdleCallback(
- () => {
- // Setting values now available after comfyApp.setup.
- // Load keybindings.
- wrapWithErrorHandling(useKeybindingService().registerUserKeybindings)()
+ runWhenGlobalIdle(() => {
+ // Setting values now available after comfyApp.setup.
+ // Load keybindings.
+ wrapWithErrorHandling(useKeybindingService().registerUserKeybindings)()
- // Load server config
- wrapWithErrorHandling(useServerConfigStore().loadServerConfig)(
- SERVER_CONFIG_ITEMS,
- settingStore.get('Comfy.Server.ServerConfigValues')
- )
+ // Load server config
+ wrapWithErrorHandling(useServerConfigStore().loadServerConfig)(
+ SERVER_CONFIG_ITEMS,
+ settingStore.get('Comfy.Server.ServerConfigValues')
+ )
- // Load model folders
- void wrapWithErrorHandlingAsync(useModelStore().loadModelFolders)()
+ // Load model folders
+ void wrapWithErrorHandlingAsync(useModelStore().loadModelFolders)()
- // Non-blocking load of node frequencies
- void wrapWithErrorHandlingAsync(
- useNodeFrequencyStore().loadNodeFrequencies
- )()
+ // Non-blocking load of node frequencies
+ void wrapWithErrorHandlingAsync(
+ useNodeFrequencyStore().loadNodeFrequencies
+ )()
- // Node defs now available after comfyApp.setup.
- // Explicitly initialize nodeSearchService to avoid indexing delay when
- // node search is triggered
- useNodeDefStore().nodeSearchService.searchNode('')
- },
- { timeout: 1000 }
- )
+ // Node defs now available after comfyApp.setup.
+ // Explicitly initialize nodeSearchService to avoid indexing delay when
+ // node search is triggered
+ useNodeDefStore().nodeSearchService.searchNode('')
+ }, 1000)
}
diff --git a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts
index d7d4f74dc..bef33733b 100644
--- a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts
+++ b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts
@@ -1,10 +1,33 @@
-import { describe, expect, it } from 'vitest'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
import { nextTick } from 'vue'
import { useAssetBrowser } from '@/platform/assets/composables/useAssetBrowser'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
+import { assetService } from '@/platform/assets/services/assetService'
+
+vi.mock('@/platform/assets/services/assetService', () => ({
+ assetService: {
+ getAssetDetails: vi.fn()
+ }
+}))
+
+vi.mock('@/i18n', () => ({
+ t: (key: string) => {
+ const translations: Record = {
+ 'assetBrowser.allModels': 'All Models',
+ 'assetBrowser.assets': 'Assets',
+ 'assetBrowser.unknown': 'unknown'
+ }
+ return translations[key] || key
+ },
+ d: (date: Date) => date.toLocaleDateString()
+}))
describe('useAssetBrowser', () => {
+ beforeEach(() => {
+ vi.restoreAllMocks()
+ })
+
// Test fixtures - minimal data focused on functionality being tested
const createApiAsset = (overrides: Partial = {}): AssetItem => ({
id: 'test-id',
@@ -26,8 +49,8 @@ describe('useAssetBrowser', () => {
user_metadata: { description: 'Test model' }
})
- const { transformAssetForDisplay } = useAssetBrowser([apiAsset])
- const result = transformAssetForDisplay(apiAsset)
+ const { filteredAssets } = useAssetBrowser([apiAsset])
+ const result = filteredAssets.value[0] // Get the transformed asset from filteredAssets
// Preserves API properties
expect(result.id).toBe(apiAsset.id)
@@ -49,15 +72,13 @@ describe('useAssetBrowser', () => {
user_metadata: undefined
})
- const { transformAssetForDisplay } = useAssetBrowser([apiAsset])
- const result = transformAssetForDisplay(apiAsset)
+ const { filteredAssets } = useAssetBrowser([apiAsset])
+ const result = filteredAssets.value[0]
expect(result.description).toBe('loras model')
})
it('formats various file sizes correctly', () => {
- const { transformAssetForDisplay } = useAssetBrowser([])
-
const testCases = [
{ size: 512, expected: '512 B' },
{ size: 1536, expected: '1.5 KB' },
@@ -67,7 +88,8 @@ describe('useAssetBrowser', () => {
testCases.forEach(({ size, expected }) => {
const asset = createApiAsset({ size })
- const result = transformAssetForDisplay(asset)
+ const { filteredAssets } = useAssetBrowser([asset])
+ const result = filteredAssets.value[0]
expect(result.formattedSize).toBe(expected)
})
})
@@ -236,18 +258,182 @@ describe('useAssetBrowser', () => {
})
})
- describe('Asset Selection', () => {
- it('returns selected asset UUID for efficient handling', () => {
+ describe('Async Asset Selection with Detail Fetching', () => {
+ it('should fetch asset details and call onSelect with filename when provided', async () => {
+ const onSelectSpy = vi.fn()
const asset = createApiAsset({
- id: 'test-uuid-123',
- name: 'selected_model.safetensors'
+ id: 'asset-123',
+ name: 'test-model.safetensors'
})
- const { selectAsset, transformAssetForDisplay } = useAssetBrowser([asset])
- const displayAsset = transformAssetForDisplay(asset)
- const result = selectAsset(displayAsset)
+ const detailAsset = createApiAsset({
+ id: 'asset-123',
+ name: 'test-model.safetensors',
+ user_metadata: { filename: 'checkpoints/test-model.safetensors' }
+ })
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
- expect(result).toBe('test-uuid-123')
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id, onSelectSpy)
+
+ expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-123')
+ expect(onSelectSpy).toHaveBeenCalledWith(
+ 'checkpoints/test-model.safetensors'
+ )
+ })
+
+ it('should handle missing user_metadata.filename as error', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+ const asset = createApiAsset({ id: 'asset-456' })
+
+ const detailAsset = createApiAsset({
+ id: 'asset-456',
+ user_metadata: { filename: '' } // Invalid empty filename
+ })
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id, onSelectSpy)
+
+ expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-456')
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Invalid asset filename:',
+ expect.arrayContaining([
+ expect.objectContaining({
+ message: 'Filename cannot be empty'
+ })
+ ]),
+ 'for asset:',
+ 'asset-456'
+ )
+ })
+
+ it('should handle API errors gracefully', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+ const asset = createApiAsset({ id: 'asset-789' })
+
+ const apiError = new Error('API Error')
+ vi.mocked(assetService.getAssetDetails).mockRejectedValue(apiError)
+
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id, onSelectSpy)
+
+ expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-789')
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ expect.stringContaining('Failed to fetch asset details for asset-789'),
+ apiError
+ )
+ })
+
+ it('should not fetch details when no callback provided', async () => {
+ const asset = createApiAsset({ id: 'asset-no-callback' })
+
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id)
+
+ expect(assetService.getAssetDetails).not.toHaveBeenCalled()
+ })
+ })
+
+ describe('Filename Validation Security', () => {
+ const createValidationTest = (filename: string) => {
+ const testAsset = createApiAsset({ id: 'validation-test' })
+ const detailAsset = createApiAsset({
+ id: 'validation-test',
+ user_metadata: { filename }
+ })
+ return { testAsset, detailAsset }
+ }
+
+ it('accepts valid file paths with forward slashes', async () => {
+ const onSelectSpy = vi.fn()
+ const { testAsset, detailAsset } = createValidationTest(
+ 'models/checkpoints/v1/test-model.safetensors'
+ )
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([testAsset])
+ await selectAssetWithCallback(testAsset.id, onSelectSpy)
+
+ expect(onSelectSpy).toHaveBeenCalledWith(
+ 'models/checkpoints/v1/test-model.safetensors'
+ )
+ })
+
+ it('rejects directory traversal attacks', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+
+ const maliciousPaths = [
+ '../malicious-model.safetensors',
+ 'models/../../../etc/passwd',
+ '/etc/passwd'
+ ]
+
+ for (const path of maliciousPaths) {
+ const { testAsset, detailAsset } = createValidationTest(path)
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([testAsset])
+ await selectAssetWithCallback(testAsset.id, onSelectSpy)
+
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Invalid asset filename:',
+ expect.arrayContaining([
+ expect.objectContaining({
+ message: 'Path must not start with / or contain ..'
+ })
+ ]),
+ 'for asset:',
+ 'validation-test'
+ )
+ }
+ })
+
+ it('rejects invalid filename characters', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+
+ const invalidChars = ['\\', ':', '*', '?', '"', '<', '>', '|']
+
+ for (const char of invalidChars) {
+ const { testAsset, detailAsset } = createValidationTest(
+ `bad${char}filename.safetensors`
+ )
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([testAsset])
+ await selectAssetWithCallback(testAsset.id, onSelectSpy)
+
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Invalid asset filename:',
+ expect.arrayContaining([
+ expect.objectContaining({
+ message: 'Invalid filename characters'
+ })
+ ]),
+ 'for asset:',
+ 'validation-test'
+ )
+ }
})
})
diff --git a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts
index fefeeceac..102aa7a18 100644
--- a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts
+++ b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts
@@ -6,11 +6,18 @@ import { useDialogStore } from '@/stores/dialogStore'
// Mock the dialog store
vi.mock('@/stores/dialogStore')
+// Mock the asset service
+vi.mock('@/platform/assets/services/assetService', () => ({
+ assetService: {
+ getAssetsForNodeType: vi.fn().mockResolvedValue([])
+ }
+}))
+
// Test factory functions
interface AssetBrowserProps {
nodeType: string
inputName: string
- onAssetSelected?: ReturnType
+ onAssetSelected?: (filename: string) => void
}
function createAssetBrowserProps(
@@ -25,7 +32,7 @@ function createAssetBrowserProps(
describe('useAssetBrowserDialog', () => {
describe('Asset Selection Flow', () => {
- it('auto-closes dialog when asset is selected', () => {
+ it('auto-closes dialog when asset is selected', async () => {
// Create fresh mocks for this test
const mockShowDialog = vi.fn()
const mockCloseDialog = vi.fn()
@@ -41,7 +48,7 @@ describe('useAssetBrowserDialog', () => {
const onAssetSelected = vi.fn()
const props = createAssetBrowserProps({ onAssetSelected })
- assetBrowserDialog.show(props)
+ await assetBrowserDialog.show(props)
// Get the onSelect handler that was passed to the dialog
const dialogCall = mockShowDialog.mock.calls[0][0]
@@ -50,14 +57,14 @@ describe('useAssetBrowserDialog', () => {
// Simulate asset selection
onSelectHandler('selected-asset-path')
- // Should call the original callback and close dialog
+ // Should call the original callback and trigger hide animation
expect(onAssetSelected).toHaveBeenCalledWith('selected-asset-path')
expect(mockCloseDialog).toHaveBeenCalledWith({
key: 'global-asset-browser'
})
})
- it('closes dialog when close handler is called', () => {
+ it('closes dialog when close handler is called', async () => {
// Create fresh mocks for this test
const mockShowDialog = vi.fn()
const mockCloseDialog = vi.fn()
@@ -72,7 +79,7 @@ describe('useAssetBrowserDialog', () => {
const assetBrowserDialog = useAssetBrowserDialog()
const props = createAssetBrowserProps()
- assetBrowserDialog.show(props)
+ await assetBrowserDialog.show(props)
// Get the onClose handler that was passed to the dialog
const dialogCall = mockShowDialog.mock.calls[0][0]
diff --git a/tests-ui/tests/composables/node/useNodePricing.test.ts b/tests-ui/tests/composables/node/useNodePricing.test.ts
index 6cd76cb75..32b18ed68 100644
--- a/tests-ui/tests/composables/node/useNodePricing.test.ts
+++ b/tests-ui/tests/composables/node/useNodePricing.test.ts
@@ -1894,4 +1894,159 @@ describe('useNodePricing', () => {
expect(getNodeDisplayPrice(missingDuration)).toBe('Token-based')
})
})
+
+ describe('dynamic pricing - WanTextToVideoApi', () => {
+ it('should return $1.50 for 10s at 1080p', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanTextToVideoApi', [
+ { name: 'duration', value: '10' },
+ { name: 'size', value: '1080p: 4:3 (1632x1248)' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$1.50/Run') // 0.15 * 10
+ })
+
+ it('should return $0.50 for 5s at 720p', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanTextToVideoApi', [
+ { name: 'duration', value: 5 },
+ { name: 'size', value: '720p: 16:9 (1280x720)' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.50/Run') // 0.10 * 5
+ })
+
+ it('should return $0.15 for 3s at 480p', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanTextToVideoApi', [
+ { name: 'duration', value: '3' },
+ { name: 'size', value: '480p: 1:1 (624x624)' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.15/Run') // 0.05 * 3
+ })
+
+ it('should fall back when widgets are missing', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const missingBoth = createMockNode('WanTextToVideoApi', [])
+ const missingSize = createMockNode('WanTextToVideoApi', [
+ { name: 'duration', value: '5' }
+ ])
+ const missingDuration = createMockNode('WanTextToVideoApi', [
+ { name: 'size', value: '1080p' }
+ ])
+
+ expect(getNodeDisplayPrice(missingBoth)).toBe('$0.05-0.15/second')
+ expect(getNodeDisplayPrice(missingSize)).toBe('$0.05-0.15/second')
+ expect(getNodeDisplayPrice(missingDuration)).toBe('$0.05-0.15/second')
+ })
+
+ it('should fall back on invalid duration', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanTextToVideoApi', [
+ { name: 'duration', value: 'invalid' },
+ { name: 'size', value: '1080p' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.05-0.15/second')
+ })
+
+ it('should fall back on unknown resolution', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanTextToVideoApi', [
+ { name: 'duration', value: '10' },
+ { name: 'size', value: '2K' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.05-0.15/second')
+ })
+ })
+
+ describe('dynamic pricing - WanImageToVideoApi', () => {
+ it('should return $0.80 for 8s at 720p', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanImageToVideoApi', [
+ { name: 'duration', value: 8 },
+ { name: 'resolution', value: '720p' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.80/Run') // 0.10 * 8
+ })
+
+ it('should return $0.60 for 12s at 480P', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanImageToVideoApi', [
+ { name: 'duration', value: '12' },
+ { name: 'resolution', value: '480P' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.60/Run') // 0.05 * 12
+ })
+
+ it('should return $1.50 for 10s at 1080p', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanImageToVideoApi', [
+ { name: 'duration', value: '10' },
+ { name: 'resolution', value: '1080p' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$1.50/Run') // 0.15 * 10
+ })
+
+ it('should handle "5s" string duration at 1080P', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanImageToVideoApi', [
+ { name: 'duration', value: '5s' },
+ { name: 'resolution', value: '1080P' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.75/Run') // 0.15 * 5
+ })
+
+ it('should fall back when widgets are missing', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const missingBoth = createMockNode('WanImageToVideoApi', [])
+ const missingRes = createMockNode('WanImageToVideoApi', [
+ { name: 'duration', value: '5' }
+ ])
+ const missingDuration = createMockNode('WanImageToVideoApi', [
+ { name: 'resolution', value: '1080p' }
+ ])
+
+ expect(getNodeDisplayPrice(missingBoth)).toBe('$0.05-0.15/second')
+ expect(getNodeDisplayPrice(missingRes)).toBe('$0.05-0.15/second')
+ expect(getNodeDisplayPrice(missingDuration)).toBe('$0.05-0.15/second')
+ })
+
+ it('should fall back on invalid duration', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanImageToVideoApi', [
+ { name: 'duration', value: 'invalid' },
+ { name: 'resolution', value: '720p' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.05-0.15/second')
+ })
+
+ it('should fall back on unknown resolution', () => {
+ const { getNodeDisplayPrice } = useNodePricing()
+ const node = createMockNode('WanImageToVideoApi', [
+ { name: 'duration', value: '10' },
+ { name: 'resolution', value: 'weird-res' }
+ ])
+
+ const price = getNodeDisplayPrice(node)
+ expect(price).toBe('$0.05-0.15/second')
+ })
+ })
})
diff --git a/tests-ui/tests/performance/transformPerformance.test.ts b/tests-ui/tests/performance/transformPerformance.test.ts
index e9f995e97..1f2fb83f7 100644
--- a/tests-ui/tests/performance/transformPerformance.test.ts
+++ b/tests-ui/tests/performance/transformPerformance.test.ts
@@ -14,7 +14,7 @@ const createMockCanvasContext = () => ({
const isCI = Boolean(process.env.CI)
const describeIfNotCI = isCI ? describe.skip : describe
-describeIfNotCI('Transform Performance', () => {
+describeIfNotCI.skip('Transform Performance', () => {
let transformState: ReturnType
let mockCanvas: any
diff --git a/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts b/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts
index 07c0a3081..0615e9b9a 100644
--- a/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts
+++ b/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts
@@ -1,14 +1,38 @@
import { createTestingPinia } from '@pinia/testing'
import { mount } from '@vue/test-utils'
import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { computed, ref } from 'vue'
+import { computed, toValue } from 'vue'
+import type { ComponentProps } from 'vue-component-type-helpers'
import { createI18n } from 'vue-i18n'
import type { VueNodeData } from '@/composables/graph/useGraphNodeManager'
-import { SelectedNodeIdsKey } from '@/renderer/core/canvas/injectionKeys'
import LGraphNode from '@/renderer/extensions/vueNodes/components/LGraphNode.vue'
+import { useNodeEventHandlers } from '@/renderer/extensions/vueNodes/composables/useNodeEventHandlers'
import { useVueElementTracking } from '@/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking'
-import { useNodeExecutionState } from '@/renderer/extensions/vueNodes/execution/useNodeExecutionState'
+
+const mockData = vi.hoisted(() => ({
+ mockNodeIds: new Set(),
+ mockExecuting: false
+}))
+
+vi.mock('@/renderer/core/canvas/canvasStore', () => {
+ const getCanvas = vi.fn()
+ const useCanvasStore = () => ({
+ getCanvas,
+ selectedNodeIds: computed(() => mockData.mockNodeIds)
+ })
+ return {
+ useCanvasStore
+ }
+})
+
+vi.mock(
+ '@/renderer/extensions/vueNodes/composables/useNodeEventHandlers',
+ () => {
+ const handleNodeSelect = vi.fn()
+ return { useNodeEventHandlers: () => ({ handleNodeSelect }) }
+ }
+)
vi.mock(
'@/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking',
@@ -47,7 +71,7 @@ vi.mock(
'@/renderer/extensions/vueNodes/execution/useNodeExecutionState',
() => ({
useNodeExecutionState: vi.fn(() => ({
- executing: computed(() => false),
+ executing: computed(() => mockData.mockExecuting),
progress: computed(() => undefined),
progressPercentage: computed(() => undefined),
progressState: computed(() => undefined as any),
@@ -72,61 +96,56 @@ const i18n = createI18n({
}
}
})
+function mountLGraphNode(props: ComponentProps) {
+ return mount(LGraphNode, {
+ props,
+ global: {
+ plugins: [
+ createTestingPinia({
+ createSpy: vi.fn
+ }),
+ i18n
+ ],
+ stubs: {
+ NodeHeader: true,
+ NodeSlots: true,
+ NodeWidgets: true,
+ NodeContent: true,
+ SlotConnectionDot: true
+ }
+ }
+ })
+}
+const mockNodeData: VueNodeData = {
+ id: 'test-node-123',
+ title: 'Test Node',
+ type: 'TestNode',
+ mode: 0,
+ flags: {},
+ inputs: [],
+ outputs: [],
+ widgets: [],
+ selected: false,
+ executing: false
+}
describe('LGraphNode', () => {
- const mockNodeData: VueNodeData = {
- id: 'test-node-123',
- title: 'Test Node',
- type: 'TestNode',
- mode: 0,
- flags: {},
- inputs: [],
- outputs: [],
- widgets: [],
- selected: false,
- executing: false
- }
-
- const mountLGraphNode = (props: any, selectedNodeIds = new Set()) => {
- return mount(LGraphNode, {
- props,
- global: {
- plugins: [
- createTestingPinia({
- createSpy: vi.fn
- }),
- i18n
- ],
- provide: {
- [SelectedNodeIdsKey as symbol]: ref(selectedNodeIds)
- },
- stubs: {
- NodeHeader: true,
- NodeSlots: true,
- NodeWidgets: true,
- NodeContent: true,
- SlotConnectionDot: true
- }
- }
- })
- }
-
beforeEach(() => {
- vi.clearAllMocks()
- // Reset to default mock
- vi.mocked(useNodeExecutionState).mockReturnValue({
- executing: computed(() => false),
- progress: computed(() => undefined),
- progressPercentage: computed(() => undefined),
- progressState: computed(() => undefined as any),
- executionState: computed(() => 'idle' as const)
- })
+ vi.resetAllMocks()
+ mockData.mockNodeIds = new Set()
+ mockData.mockExecuting = false
})
it('should call resize tracking composable with node ID', () => {
mountLGraphNode({ nodeData: mockNodeData })
- expect(useVueElementTracking).toHaveBeenCalledWith('test-node-123', 'node')
+ expect(useVueElementTracking).toHaveBeenCalledWith(
+ expect.any(Function),
+ 'node'
+ )
+ const idArg = vi.mocked(useVueElementTracking).mock.calls[0]?.[0]
+ const id = toValue(idArg)
+ expect(id).toEqual('test-node-123')
})
it('should render with data-node-id attribute', () => {
@@ -146,9 +165,6 @@ describe('LGraphNode', () => {
}),
i18n
],
- provide: {
- [SelectedNodeIdsKey as symbol]: ref(new Set())
- },
stubs: {
NodeSlots: true,
NodeWidgets: true,
@@ -162,24 +178,15 @@ describe('LGraphNode', () => {
})
it('should apply selected styling when selected prop is true', () => {
- const wrapper = mountLGraphNode(
- { nodeData: mockNodeData, selected: true },
- new Set(['test-node-123'])
- )
+ mockData.mockNodeIds = new Set(['test-node-123'])
+ const wrapper = mountLGraphNode({ nodeData: mockNodeData })
expect(wrapper.classes()).toContain('outline-2')
expect(wrapper.classes()).toContain('outline-black')
expect(wrapper.classes()).toContain('dark-theme:outline-white')
})
it('should apply executing animation when executing prop is true', () => {
- // Mock the execution state to return executing: true
- vi.mocked(useNodeExecutionState).mockReturnValue({
- executing: computed(() => true),
- progress: computed(() => undefined),
- progressPercentage: computed(() => undefined),
- progressState: computed(() => undefined as any),
- executionState: computed(() => 'running' as const)
- })
+ mockData.mockExecuting = true
const wrapper = mountLGraphNode({ nodeData: mockNodeData })
@@ -187,12 +194,16 @@ describe('LGraphNode', () => {
})
it('should emit node-click event on pointer up', async () => {
+ const { handleNodeSelect } = useNodeEventHandlers()
const wrapper = mountLGraphNode({ nodeData: mockNodeData })
await wrapper.trigger('pointerup')
- expect(wrapper.emitted('node-click')).toHaveLength(1)
- expect(wrapper.emitted('node-click')?.[0]).toHaveLength(3)
- expect(wrapper.emitted('node-click')?.[0][1]).toEqual(mockNodeData)
+ expect(handleNodeSelect).toHaveBeenCalledOnce()
+ expect(handleNodeSelect).toHaveBeenCalledWith(
+ expect.any(PointerEvent),
+ mockNodeData,
+ expect.any(Boolean)
+ )
})
})
diff --git a/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts
index e2a4bd920..dd08457eb 100644
--- a/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts
+++ b/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts
@@ -1,98 +1,82 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
-import { computed, ref } from 'vue'
+import { computed, shallowRef } from 'vue'
-import type { VueNodeData } from '@/composables/graph/useGraphNodeManager'
-import type { useGraphNodeManager } from '@/composables/graph/useGraphNodeManager'
-import type { LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph'
+import {
+ type GraphNodeManager,
+ type VueNodeData,
+ useGraphNodeManager
+} from '@/composables/graph/useGraphNodeManager'
+import { useVueNodeLifecycle } from '@/composables/graph/useVueNodeLifecycle'
+import type {
+ LGraph,
+ LGraphCanvas,
+ LGraphNode
+} from '@/lib/litegraph/src/litegraph'
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
-import { useCanvasInteractions } from '@/renderer/core/canvas/useCanvasInteractions'
import { useLayoutMutations } from '@/renderer/core/layout/operations/layoutMutations'
import { useNodeEventHandlers } from '@/renderer/extensions/vueNodes/composables/useNodeEventHandlers'
-vi.mock('@/renderer/core/canvas/canvasStore', () => ({
- useCanvasStore: vi.fn()
-}))
-
-vi.mock('@/renderer/core/canvas/useCanvasInteractions', () => ({
- useCanvasInteractions: vi.fn()
-}))
-
-vi.mock('@/renderer/core/layout/operations/layoutMutations', () => ({
- useLayoutMutations: vi.fn()
-}))
-
-vi.mock('@/composables/graph/useGraphNodeManager', () => ({
- useGraphNodeManager: vi.fn()
-}))
-
-function createMockCanvas(): Pick<
- LGraphCanvas,
- 'select' | 'deselect' | 'deselectAll'
-> {
- return {
+vi.mock('@/renderer/core/canvas/canvasStore', () => {
+ const canvas: Partial = {
select: vi.fn(),
deselect: vi.fn(),
deselectAll: vi.fn()
}
-}
-
-function createMockNode(): Pick {
+ const updateSelectedItems = vi.fn()
return {
+ useCanvasStore: vi.fn(() => ({
+ canvas: canvas as LGraphCanvas,
+ updateSelectedItems,
+ selectedItems: []
+ }))
+ }
+})
+
+vi.mock('@/renderer/core/canvas/useCanvasInteractions', () => ({
+ useCanvasInteractions: vi.fn(() => ({
+ shouldHandleNodePointerEvents: computed(() => true) // Default to allowing pointer events
+ }))
+}))
+
+vi.mock('@/renderer/core/layout/operations/layoutMutations', () => {
+ const setSource = vi.fn()
+ const bringNodeToFront = vi.fn()
+ return {
+ useLayoutMutations: vi.fn(() => ({
+ setSource,
+ bringNodeToFront
+ }))
+ }
+})
+
+vi.mock('@/composables/graph/useGraphNodeManager', () => {
+ const mockNode = {
id: 'node-1',
selected: false,
flags: { pinned: false }
}
-}
-
-function createMockNodeManager(
- node: Pick
-) {
+ const nodeManager = shallowRef({
+ getNode: vi.fn(() => mockNode as Partial as LGraphNode)
+ } as Partial as GraphNodeManager)
return {
- getNode: vi.fn().mockReturnValue(node) as ReturnType<
- typeof useGraphNodeManager
- >['getNode']
+ useGraphNodeManager: vi.fn(() => nodeManager)
}
-}
+})
-function createMockCanvasStore(
- canvas: Pick
-): Pick<
- ReturnType,
- 'canvas' | 'selectedItems' | 'updateSelectedItems'
-> {
+vi.mock('@/composables/graph/useVueNodeLifecycle', () => {
+ const nodeManager = useGraphNodeManager(undefined as unknown as LGraph)
return {
- canvas: canvas as LGraphCanvas,
- selectedItems: [],
- updateSelectedItems: vi.fn()
+ useVueNodeLifecycle: vi.fn(() => ({
+ nodeManager
+ }))
}
-}
-
-function createMockLayoutMutations(): Pick<
- ReturnType,
- 'setSource' | 'bringNodeToFront'
-> {
- return {
- setSource: vi.fn(),
- bringNodeToFront: vi.fn()
- }
-}
-
-function createMockCanvasInteractions(): Pick<
- ReturnType,
- 'shouldHandleNodePointerEvents'
-> {
- return {
- shouldHandleNodePointerEvents: computed(() => true) // Default to allowing pointer events
- }
-}
+})
describe('useNodeEventHandlers', () => {
- let mockCanvas: ReturnType
- let mockNode: ReturnType
- let mockNodeManager: ReturnType
- let mockCanvasStore: ReturnType
- let mockLayoutMutations: ReturnType
- let mockCanvasInteractions: ReturnType
+ const { nodeManager: mockNodeManager } = useVueNodeLifecycle()
+
+ const mockNode = mockNodeManager.value!.getNode('fake_id')
+ const mockLayoutMutations = useLayoutMutations()
const testNodeData: VueNodeData = {
id: 'node-1',
@@ -104,28 +88,13 @@ describe('useNodeEventHandlers', () => {
}
beforeEach(async () => {
- mockNode = createMockNode()
- mockCanvas = createMockCanvas()
- mockNodeManager = createMockNodeManager(mockNode)
- mockCanvasStore = createMockCanvasStore(mockCanvas)
- mockLayoutMutations = createMockLayoutMutations()
- mockCanvasInteractions = createMockCanvasInteractions()
-
- vi.mocked(useCanvasStore).mockReturnValue(
- mockCanvasStore as ReturnType
- )
- vi.mocked(useLayoutMutations).mockReturnValue(
- mockLayoutMutations as ReturnType
- )
- vi.mocked(useCanvasInteractions).mockReturnValue(
- mockCanvasInteractions as ReturnType
- )
+ vi.restoreAllMocks()
})
describe('handleNodeSelect', () => {
it('should select single node on regular click', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
+ const { handleNodeSelect } = useNodeEventHandlers()
+ const { canvas, updateSelectedItems } = useCanvasStore()
const event = new PointerEvent('pointerdown', {
bubbles: true,
@@ -135,17 +104,17 @@ describe('useNodeEventHandlers', () => {
handleNodeSelect(event, testNodeData, false)
- expect(mockCanvas.deselectAll).toHaveBeenCalledOnce()
- expect(mockCanvas.select).toHaveBeenCalledWith(mockNode)
- expect(mockCanvasStore.updateSelectedItems).toHaveBeenCalledOnce()
+ expect(canvas?.deselectAll).toHaveBeenCalledOnce()
+ expect(canvas?.select).toHaveBeenCalledWith(mockNode)
+ expect(updateSelectedItems).toHaveBeenCalledOnce()
})
it('should toggle selection on ctrl+click', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
+ const { handleNodeSelect } = useNodeEventHandlers()
+ const { canvas } = useCanvasStore()
// Test selecting unselected node with ctrl
- mockNode.selected = false
+ mockNode!.selected = false
const ctrlClickEvent = new PointerEvent('pointerdown', {
bubbles: true,
@@ -155,16 +124,16 @@ describe('useNodeEventHandlers', () => {
handleNodeSelect(ctrlClickEvent, testNodeData, false)
- expect(mockCanvas.deselectAll).not.toHaveBeenCalled()
- expect(mockCanvas.select).toHaveBeenCalledWith(mockNode)
+ expect(canvas?.deselectAll).not.toHaveBeenCalled()
+ expect(canvas?.select).toHaveBeenCalledWith(mockNode)
})
it('should deselect on ctrl+click of selected node', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
+ const { handleNodeSelect } = useNodeEventHandlers()
+ const { canvas } = useCanvasStore()
// Test deselecting selected node with ctrl
- mockNode.selected = true
+ mockNode!.selected = true
const ctrlClickEvent = new PointerEvent('pointerdown', {
bubbles: true,
@@ -174,15 +143,15 @@ describe('useNodeEventHandlers', () => {
handleNodeSelect(ctrlClickEvent, testNodeData, false)
- expect(mockCanvas.deselect).toHaveBeenCalledWith(mockNode)
- expect(mockCanvas.select).not.toHaveBeenCalled()
+ expect(canvas?.deselect).toHaveBeenCalledWith(mockNode)
+ expect(canvas?.select).not.toHaveBeenCalled()
})
it('should handle meta key (Cmd) on Mac', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
+ const { handleNodeSelect } = useNodeEventHandlers()
+ const { canvas } = useCanvasStore()
- mockNode.selected = false
+ mockNode!.selected = false
const metaClickEvent = new PointerEvent('pointerdown', {
bubbles: true,
@@ -192,15 +161,14 @@ describe('useNodeEventHandlers', () => {
handleNodeSelect(metaClickEvent, testNodeData, false)
- expect(mockCanvas.select).toHaveBeenCalledWith(mockNode)
- expect(mockCanvas.deselectAll).not.toHaveBeenCalled()
+ expect(canvas?.select).toHaveBeenCalledWith(mockNode)
+ expect(canvas?.deselectAll).not.toHaveBeenCalled()
})
it('should bring node to front when not pinned', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
+ const { handleNodeSelect } = useNodeEventHandlers()
- mockNode.flags.pinned = false
+ mockNode!.flags.pinned = false
const event = new PointerEvent('pointerdown')
handleNodeSelect(event, testNodeData, false)
@@ -211,49 +179,14 @@ describe('useNodeEventHandlers', () => {
})
it('should not bring pinned node to front', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
+ const { handleNodeSelect } = useNodeEventHandlers()
- mockNode.flags.pinned = true
+ mockNode!.flags.pinned = true
const event = new PointerEvent('pointerdown')
handleNodeSelect(event, testNodeData, false)
expect(mockLayoutMutations.bringNodeToFront).not.toHaveBeenCalled()
})
-
- it('should handle missing canvas gracefully', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
-
- mockCanvasStore.canvas = null
-
- const event = new PointerEvent('pointerdown')
- expect(() => {
- handleNodeSelect(event, testNodeData, false)
- }).not.toThrow()
-
- expect(mockCanvas.select).not.toHaveBeenCalled()
- })
-
- it('should handle missing node gracefully', () => {
- const nodeManager = ref(mockNodeManager)
- const { handleNodeSelect } = useNodeEventHandlers(nodeManager)
-
- vi.mocked(mockNodeManager.getNode).mockReturnValue(undefined)
-
- const event = new PointerEvent('pointerdown')
- const nodeData = {
- id: 'missing-node',
- title: 'Missing Node',
- type: 'test'
- } as any
-
- expect(() => {
- handleNodeSelect(event, nodeData, false)
- }).not.toThrow()
-
- expect(mockCanvas.select).not.toHaveBeenCalled()
- })
})
})
diff --git a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts
index 875919ccd..d439d98ca 100644
--- a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts
+++ b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts
@@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
+import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
import { assetService } from '@/platform/assets/services/assetService'
import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget'
import type { InputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
@@ -29,13 +30,25 @@ vi.mock('@/platform/assets/services/assetService', () => ({
}
}))
+vi.mock('@/platform/assets/composables/useAssetBrowserDialog', () => {
+ const mockAssetBrowserDialogShow = vi.fn()
+ return {
+ useAssetBrowserDialog: vi.fn(() => ({
+ show: mockAssetBrowserDialogShow
+ }))
+ }
+})
+
// Test factory functions
function createMockWidget(overrides: Partial = {}): IBaseWidget {
+ const mockCallback = vi.fn()
return {
type: 'combo',
options: {},
name: 'testWidget',
value: undefined,
+ callback: mockCallback,
+ y: 0,
...overrides
} as IBaseWidget
}
@@ -45,7 +58,16 @@ function createMockNode(comfyClass = 'TestNode'): LGraphNode {
node.comfyClass = comfyClass
// Spy on the addWidget method
- vi.spyOn(node, 'addWidget').mockReturnValue(createMockWidget())
+ vi.spyOn(node, 'addWidget').mockImplementation(
+ (type, name, value, callback) => {
+ const widget = createMockWidget({ type, name, value })
+ // Store the callback function on the widget for testing
+ if (typeof callback === 'function') {
+ widget.callback = callback
+ }
+ return widget
+ }
+ )
return node
}
@@ -61,9 +83,9 @@ function createMockInputSpec(overrides: Partial = {}): InputSpec {
describe('useComboWidget', () => {
beforeEach(() => {
vi.clearAllMocks()
- // Reset to defaults
mockSettingStoreGet.mockReturnValue(false)
vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false)
+ vi.mocked(useAssetBrowserDialog).mockClear()
})
it('should handle undefined spec', () => {
diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts
index d96ef765b..7a719c4e4 100644
--- a/tests-ui/tests/services/assetService.test.ts
+++ b/tests-ui/tests/services/assetService.test.ts
@@ -4,6 +4,8 @@ import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { assetService } from '@/platform/assets/services/assetService'
import { api } from '@/scripts/api'
+const mockGetCategoryForNodeType = vi.fn()
+
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: vi.fn(() => ({
getRegisteredNodeTypes: vi.fn(
@@ -14,7 +16,13 @@ vi.mock('@/stores/modelToNodeStore', () => ({
'VAELoader',
'TestNode'
])
- )
+ ),
+ getCategoryForNodeType: mockGetCategoryForNodeType,
+ modelToNodeMap: {
+ checkpoints: [{ nodeDef: { name: 'CheckpointLoaderSimple' } }],
+ loras: [{ nodeDef: { name: 'LoraLoader' } }],
+ vae: [{ nodeDef: { name: 'VAELoader' } }]
+ }
}))
}))
@@ -210,4 +218,87 @@ describe('assetService', () => {
).toBe(false)
})
})
+
+ describe('getAssetsForNodeType', () => {
+ beforeEach(() => {
+ mockGetCategoryForNodeType.mockClear()
+ })
+
+ it('should return empty array for unregistered node types', async () => {
+ mockGetCategoryForNodeType.mockReturnValue(undefined)
+
+ const result = await assetService.getAssetsForNodeType('UnknownNode')
+
+ expect(mockGetCategoryForNodeType).toHaveBeenCalledWith('UnknownNode')
+ expect(result).toEqual([])
+ })
+
+ it('should use getCategoryForNodeType for efficient category lookup', async () => {
+ mockGetCategoryForNodeType.mockReturnValue('checkpoints')
+ const testAssets = [MOCK_ASSETS.checkpoints]
+ mockApiResponse(testAssets)
+
+ const result = await assetService.getAssetsForNodeType(
+ 'CheckpointLoaderSimple'
+ )
+
+ expect(mockGetCategoryForNodeType).toHaveBeenCalledWith(
+ 'CheckpointLoaderSimple'
+ )
+ expect(result).toEqual(testAssets)
+
+ // Verify API call includes correct category
+ expect(api.fetchApi).toHaveBeenCalledWith(
+ '/assets?include_tags=models,checkpoints'
+ )
+ })
+
+ it('should return empty array when no category found', async () => {
+ mockGetCategoryForNodeType.mockReturnValue(undefined)
+
+ const result = await assetService.getAssetsForNodeType('TestNode')
+
+ expect(result).toEqual([])
+ expect(api.fetchApi).not.toHaveBeenCalled()
+ })
+
+ it('should handle API errors gracefully', async () => {
+ mockGetCategoryForNodeType.mockReturnValue('loras')
+ mockApiError(500, 'Internal Server Error')
+
+ await expect(
+ assetService.getAssetsForNodeType('LoraLoader')
+ ).rejects.toThrow(
+ 'Unable to load assets for LoraLoader: Server returned 500. Please try again.'
+ )
+ })
+
+ it('should return all assets without filtering for different categories', async () => {
+ // Test checkpoints
+ mockGetCategoryForNodeType.mockReturnValue('checkpoints')
+ const checkpointAssets = [MOCK_ASSETS.checkpoints]
+ mockApiResponse(checkpointAssets)
+
+ let result = await assetService.getAssetsForNodeType(
+ 'CheckpointLoaderSimple'
+ )
+ expect(result).toEqual(checkpointAssets)
+
+ // Test loras
+ mockGetCategoryForNodeType.mockReturnValue('loras')
+ const loraAssets = [MOCK_ASSETS.loras]
+ mockApiResponse(loraAssets)
+
+ result = await assetService.getAssetsForNodeType('LoraLoader')
+ expect(result).toEqual(loraAssets)
+
+ // Test vae
+ mockGetCategoryForNodeType.mockReturnValue('vae')
+ const vaeAssets = [MOCK_ASSETS.vae]
+ mockApiResponse(vaeAssets)
+
+ result = await assetService.getAssetsForNodeType('VAELoader')
+ expect(result).toEqual(vaeAssets)
+ })
+ })
})
diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts
index 179c76b9e..b07c34a41 100644
--- a/tests-ui/tests/store/modelToNodeStore.test.ts
+++ b/tests-ui/tests/store/modelToNodeStore.test.ts
@@ -19,7 +19,7 @@ const EXPECTED_DEFAULT_TYPES = [
'gligen'
] as const
-type NodeDefStoreType = typeof import('@/stores/nodeDefStore')
+type NodeDefStoreType = ReturnType
// Create minimal but valid ComfyNodeDefImpl for testing
function createMockNodeDef(name: string): ComfyNodeDefImpl {
@@ -343,6 +343,107 @@ describe('useModelToNodeStore', () => {
})
})
+ describe('getCategoryForNodeType', () => {
+ it('should return category for known node type', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ expect(
+ modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple')
+ ).toBe('checkpoints')
+ expect(modelToNodeStore.getCategoryForNodeType('LoraLoader')).toBe(
+ 'loras'
+ )
+ expect(modelToNodeStore.getCategoryForNodeType('VAELoader')).toBe('vae')
+ })
+
+ it('should return undefined for unknown node type', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ expect(
+ modelToNodeStore.getCategoryForNodeType('NonExistentNode')
+ ).toBeUndefined()
+ expect(modelToNodeStore.getCategoryForNodeType('')).toBeUndefined()
+ })
+
+ it('should return first category when node type exists in multiple categories', () => {
+ const modelToNodeStore = useModelToNodeStore()
+
+ // Test with a node that exists in the defaults but add our own first
+ // Since defaults register 'StyleModelLoader' in 'style_models',
+ // we verify our custom registrations come after defaults in Object.entries iteration
+ const result = modelToNodeStore.getCategoryForNodeType('StyleModelLoader')
+ expect(result).toBe('style_models') // This proves the method works correctly
+
+ // Now test that custom registrations after defaults also work
+ modelToNodeStore.quickRegister(
+ 'unicorn_styles',
+ 'StyleModelLoader',
+ 'param1'
+ )
+ const result2 =
+ modelToNodeStore.getCategoryForNodeType('StyleModelLoader')
+ // Should still be style_models since it was registered first by defaults
+ expect(result2).toBe('style_models')
+ })
+
+ it('should trigger lazy registration when called before registerDefaults', () => {
+ const modelToNodeStore = useModelToNodeStore()
+
+ const result = modelToNodeStore.getCategoryForNodeType(
+ 'CheckpointLoaderSimple'
+ )
+ expect(result).toBe('checkpoints')
+ })
+
+ it('should be performant for repeated lookups', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ // Measure performance without assuming implementation
+ const start = performance.now()
+ for (let i = 0; i < 1000; i++) {
+ modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple')
+ }
+ const end = performance.now()
+
+ // Should be fast enough for UI responsiveness
+ expect(end - start).toBeLessThan(10)
+ })
+
+ it('should handle invalid input types gracefully', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ // These should not throw but return undefined
+ expect(
+ modelToNodeStore.getCategoryForNodeType(null as any)
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType(undefined as any)
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType(123 as any)
+ ).toBeUndefined()
+ })
+
+ it('should be case-sensitive for node type matching', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ expect(
+ modelToNodeStore.getCategoryForNodeType('checkpointloadersimple')
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType('CHECKPOINTLOADERSIMPLE')
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple')
+ ).toBe('checkpoints')
+ })
+ })
+
describe('edge cases', () => {
it('should handle empty string model type', () => {
const modelToNodeStore = useModelToNodeStore()
diff --git a/tsconfig.json b/tsconfig.json
index de5cb4def..97346f56e 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -29,14 +29,15 @@
"rootDir": "./"
},
"include": [
- "src/**/*",
+ ".storybook/**/*",
+ "eslint.config.ts",
+ "global.d.ts",
+ "knip.config.ts",
"src/**/*.vue",
+ "src/**/*",
"src/types/**/*.d.ts",
"tests-ui/**/*",
- "global.d.ts",
- "eslint.config.ts",
"vite.config.mts",
- "knip.config.ts",
- ".storybook/**/*"
+ "vitest.config.ts",
]
}
diff --git a/vite.config.mts b/vite.config.mts
index 0ca062273..25cd730aa 100644
--- a/vite.config.mts
+++ b/vite.config.mts
@@ -31,7 +31,8 @@ export default defineConfig({
ignored: [
'**/coverage/**',
'**/playwright-report/**',
- '**/*.{test,spec}.ts'
+ '**/*.{test,spec}.ts',
+ '*.config.{ts,mts}'
]
},
proxy: {
diff --git a/vitest.config.ts b/vitest.config.ts
index 36fdb1a00..02565497e 100644
--- a/vitest.config.ts
+++ b/vitest.config.ts
@@ -32,7 +32,8 @@ export default defineConfig({
'**/.{idea,git,cache,output,temp}/**',
'**/{karma,rollup,webpack,vite,vitest,jest,ava,babel,nyc,cypress,tsup,build,eslint,prettier}.config.*',
'src/lib/litegraph/test/**'
- ]
+ ],
+ silent: 'passed-only'
},
resolve: {
alias: {