diff --git a/src/platform/assets/utils/createModelNodeFromAsset.test.ts b/src/platform/assets/utils/createModelNodeFromAsset.test.ts index 8594510b6..cdc42dc39 100644 --- a/src/platform/assets/utils/createModelNodeFromAsset.test.ts +++ b/src/platform/assets/utils/createModelNodeFromAsset.test.ts @@ -110,13 +110,19 @@ async function createMockNode(overrides?: { widgets: { value: [widget], writable: true } }) } -function createMockNodeProvider() { +function createMockNodeProvider( + overrides: { + nodeDef?: { name: string; display_name: string } + key?: string + } = {} +) { return { nodeDef: { name: 'CheckpointLoaderSimple', - display_name: 'Load Checkpoint' + display_name: 'Load Checkpoint', + ...overrides.nodeDef }, - key: 'ckpt_name' + key: overrides.key ?? 'ckpt_name' } } /** @@ -270,6 +276,24 @@ describe('createModelNodeFromAsset', () => { expect(mockSubgraph.add).toHaveBeenCalledWith(mockNode) expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled() }) + it('should succeed when provider has empty key (auto-load nodes)', async () => { + const asset = createMockAsset({ + tags: ['models', 'chatterbox/chatterbox_vc'], + user_metadata: { filename: 'chatterbox_vc_model.pt' } + }) + const mockNode = await createMockNode({ hasWidgets: false }) + const nodeProvider = createMockNodeProvider({ + nodeDef: { + name: 'FL_ChatterboxVC', + display_name: 'FL Chatterbox VC' + }, + key: '' + }) + await setupMocks({ createdNode: mockNode, nodeProvider }) + const result = createModelNodeFromAsset(asset) + expect(result.success).toBe(true) + expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode) + }) }) describe('when asset data is incomplete or invalid', () => { beforeEach(() => { diff --git a/src/platform/assets/utils/createModelNodeFromAsset.ts b/src/platform/assets/utils/createModelNodeFromAsset.ts index 7bf8b752b..b60109659 100644 --- a/src/platform/assets/utils/createModelNodeFromAsset.ts +++ b/src/platform/assets/utils/createModelNodeFromAsset.ts @@ -171,26 +171,27 @@ export function createModelNodeFromAsset( } } - const widget = node.widgets?.find((w) => w.name === provider.key) - if (!widget) { - console.error( - `Widget ${provider.key} not found on node ${provider.nodeDef.name}` - ) - return { - success: false, - error: { - code: 'MISSING_WIDGET', - message: `Widget ${provider.key} not found on node ${provider.nodeDef.name}`, - assetId: validAsset.id, - details: { widgetName: provider.key, nodeType: provider.nodeDef.name } + // Set widget value if provider specifies a key (some nodes auto-load models without a widget) + if (provider.key) { + const widget = node.widgets?.find((w) => w.name === provider.key) + if (!widget) { + console.error( + `Widget ${provider.key} not found on node ${provider.nodeDef.name}` + ) + return { + success: false, + error: { + code: 'MISSING_WIDGET', + message: `Widget ${provider.key} not found on node ${provider.nodeDef.name}`, + assetId: validAsset.id, + details: { widgetName: provider.key, nodeType: provider.nodeDef.name } + } } } + widget.value = filename } - // Set widget value BEFORE adding to graph so the node is created with correct value - widget.value = filename - - // Now add the node to the graph with the correct widget value already set + // Add the node to the graph targetGraph.add(node) return { success: true, value: node } diff --git a/src/stores/modelToNodeStore.test.ts b/src/stores/modelToNodeStore.test.ts index c79af19b5..b9528ca47 100644 --- a/src/stores/modelToNodeStore.test.ts +++ b/src/stores/modelToNodeStore.test.ts @@ -22,7 +22,11 @@ const EXPECTED_DEFAULT_TYPES = [ 'audio_encoders', 'model_patches', 'animatediff_models', - 'animatediff_motion_lora' + 'animatediff_motion_lora', + 'chatterbox/chatterbox', + 'chatterbox/chatterbox_turbo', + 'chatterbox/chatterbox_multilingual', + 'chatterbox/chatterbox_vc' ] as const type NodeDefStoreType = ReturnType @@ -60,7 +64,11 @@ const MOCK_NODE_NAMES = [ 'AudioEncoderLoader', 'ModelPatchLoader', 'ADE_LoadAnimateDiffModel', - 'ADE_AnimateDiffLoRALoader' + 'ADE_AnimateDiffLoRALoader', + 'FL_ChatterboxTTS', + 'FL_ChatterboxTurboTTS', + 'FL_ChatterboxMultilingualTTS', + 'FL_ChatterboxVC' ] as const const mockNodeDefsByName = Object.fromEntries( @@ -134,6 +142,36 @@ describe('useModelToNodeStore', () => { const provider = modelToNodeStore.getNodeProvider('checkpoints') expect(provider).toBeDefined() }) + + it('should fallback to top-level folder for hierarchical model types', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + const provider = modelToNodeStore.getNodeProvider('checkpoints/subfolder') + expect(provider).toBeDefined() + expect(provider?.nodeDef?.name).toBe('CheckpointLoaderSimple') + }) + + it('should return undefined for hierarchical type with unregistered top-level', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getNodeProvider('UnknownType/subfolder') + ).toBeUndefined() + }) + + it('should return provider for chatterbox nodes with empty key', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + const provider = modelToNodeStore.getNodeProvider( + 'chatterbox/chatterbox_vc' + ) + expect(provider).toBeDefined() + expect(provider?.nodeDef?.name).toBe('FL_ChatterboxVC') + expect(provider?.key).toBe('') + }) }) describe('getAllNodeProviders', () => { @@ -183,6 +221,17 @@ describe('useModelToNodeStore', () => { const providers = modelToNodeStore.getAllNodeProviders('checkpoints') expect(providers.length).toBeGreaterThan(0) }) + + it('should fallback to top-level folder for hierarchical model types', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + const providers = modelToNodeStore.getAllNodeProviders( + 'checkpoints/subfolder' + ) + expect(providers).toHaveLength(2) + expect(providers[0].nodeDef.name).toBe('CheckpointLoaderSimple') + }) }) describe('registerNodeProvider', () => { diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index cc9a911c8..2355da230 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -9,7 +9,7 @@ export class ModelNodeProvider { /** The node definition to use for this model. */ public nodeDef: ComfyNodeDefImpl - /** The node input key for where to inside the model name. */ + /** The node input key for where to insert the model name. */ public key: string constructor(nodeDef: ComfyNodeDefImpl, key: string) { @@ -73,23 +73,45 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { return nodeTypeToCategory.value[nodeType] } + /** + * Find providers for modelType with hierarchical fallback. + * Tries exact match first, then falls back to top-level segment (e.g., "parent/child" → "parent"). + * Note: Only falls back one level; "a/b/c" tries "a/b/c" then "a", not "a/b". + */ + function findProvidersWithFallback( + modelType: string + ): ModelNodeProvider[] | undefined { + const exactMatch = modelToNodeMap.value[modelType] + if (exactMatch && exactMatch.length > 0) return exactMatch + + const topLevel = modelType.split('/')[0] + if (topLevel !== modelType) { + const fallback = modelToNodeMap.value[topLevel] + if (fallback && fallback.length > 0) return fallback + } + return undefined + } + /** * Get the node provider for the given model type name. + * Supports hierarchical lookups: if "parent/child" has no match, falls back to "parent". * @param modelType The name of the model type to get the node provider for. * @returns The node provider for the given model type name. */ function getNodeProvider(modelType: string): ModelNodeProvider | undefined { registerDefaults() - return modelToNodeMap.value[modelType]?.[0] + return findProvidersWithFallback(modelType)?.[0] } + /** * Get the list of all valid node providers for the given model type name. + * Supports hierarchical lookups: if "parent/child" has no match, falls back to "parent". * @param modelType The name of the model type to get the node providers for. * @returns The list of all valid node providers for the given model type name. */ function getAllNodeProviders(modelType: string): ModelNodeProvider[] { registerDefaults() - return modelToNodeMap.value[modelType] ?? [] + return findProvidersWithFallback(modelType) ?? [] } /** * Register a node provider for the given model type name. @@ -153,6 +175,17 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { 'ADE_AnimateDiffLoRALoader', 'name' ) + + // Chatterbox TTS nodes: empty key means the node auto-loads models without + // a widget selector (createModelNodeFromAsset skips widget assignment) + quickRegister('chatterbox/chatterbox', 'FL_ChatterboxTTS', '') + quickRegister('chatterbox/chatterbox_turbo', 'FL_ChatterboxTurboTTS', '') + quickRegister( + 'chatterbox/chatterbox_multilingual', + 'FL_ChatterboxMultilingualTTS', + '' + ) + quickRegister('chatterbox/chatterbox_vc', 'FL_ChatterboxVC', '') } return {