mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-21 07:14:11 +00:00
feat: add Chatterbox model support for Cloud asset browser (#8418)
## Summary Adds support for creating Chatterbox TTS nodes when clicking Chatterbox models in the Cloud asset browser. ## Changes ### modelToNodeStore.ts - Add `findProvidersWithFallback()` helper for hierarchical model type lookups (e.g., `parent/child` falls back to `parent`) - Register 4 Chatterbox model directories with empty widget keys: - `chatterbox/chatterbox` → `FL_ChatterboxTTS` - `chatterbox/chatterbox_turbo` → `FL_ChatterboxTurboTTS` - `chatterbox/chatterbox_multilingual` → `FL_ChatterboxMultilingualTTS` - `chatterbox/chatterbox_vc` → `FL_ChatterboxVC` ### createModelNodeFromAsset.ts - Skip widget assignment when `provider.key` is empty (for nodes that auto-load models without a widget selector) ### Tests - Add tests for hierarchical fallback behavior - Add tests for empty widget key (auto-load nodes) - Add Chatterbox node types to mock data ## Notes - Empty `key` convention: Chatterbox nodes auto-load their models and don't have a model selector widget, so we register them with `key: ''` and skip the widget assignment step - Hierarchical fallback only goes one level deep (`a/b/c` → `a`, not `a/b`) ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-8418-feat-add-Chatterbox-model-support-for-Cloud-asset-browser-2f76d73d365081be822bc369b155f099) by [Unito](https://www.unito.io) --------- Co-authored-by: Subagent 5 <subagent@example.com> Co-authored-by: Amp <amp@ampcode.com> Co-authored-by: Alexander Brown <drjkl@comfy.org> Co-authored-by: GitHub Action <action@github.com>
This commit is contained in:
@@ -110,13 +110,19 @@ async function createMockNode(overrides?: {
|
||||
widgets: { value: [widget], writable: true }
|
||||
})
|
||||
}
|
||||
function createMockNodeProvider() {
|
||||
function createMockNodeProvider(
|
||||
overrides: {
|
||||
nodeDef?: { name: string; display_name: string }
|
||||
key?: string
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint'
|
||||
display_name: 'Load Checkpoint',
|
||||
...overrides.nodeDef
|
||||
},
|
||||
key: 'ckpt_name'
|
||||
key: overrides.key ?? 'ckpt_name'
|
||||
}
|
||||
}
|
||||
/**
|
||||
@@ -270,6 +276,24 @@ describe('createModelNodeFromAsset', () => {
|
||||
expect(mockSubgraph.add).toHaveBeenCalledWith(mockNode)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
it('should succeed when provider has empty key (auto-load nodes)', async () => {
|
||||
const asset = createMockAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
const nodeProvider = createMockNodeProvider({
|
||||
nodeDef: {
|
||||
name: 'FL_ChatterboxVC',
|
||||
display_name: 'FL Chatterbox VC'
|
||||
},
|
||||
key: ''
|
||||
})
|
||||
await setupMocks({ createdNode: mockNode, nodeProvider })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
})
|
||||
describe('when asset data is incomplete or invalid', () => {
|
||||
beforeEach(() => {
|
||||
|
||||
@@ -171,26 +171,27 @@ export function createModelNodeFromAsset(
|
||||
}
|
||||
}
|
||||
|
||||
const widget = node.widgets?.find((w) => w.name === provider.key)
|
||||
if (!widget) {
|
||||
console.error(
|
||||
`Widget ${provider.key} not found on node ${provider.nodeDef.name}`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'MISSING_WIDGET',
|
||||
message: `Widget ${provider.key} not found on node ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { widgetName: provider.key, nodeType: provider.nodeDef.name }
|
||||
// Set widget value if provider specifies a key (some nodes auto-load models without a widget)
|
||||
if (provider.key) {
|
||||
const widget = node.widgets?.find((w) => w.name === provider.key)
|
||||
if (!widget) {
|
||||
console.error(
|
||||
`Widget ${provider.key} not found on node ${provider.nodeDef.name}`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'MISSING_WIDGET',
|
||||
message: `Widget ${provider.key} not found on node ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { widgetName: provider.key, nodeType: provider.nodeDef.name }
|
||||
}
|
||||
}
|
||||
}
|
||||
widget.value = filename
|
||||
}
|
||||
|
||||
// Set widget value BEFORE adding to graph so the node is created with correct value
|
||||
widget.value = filename
|
||||
|
||||
// Now add the node to the graph with the correct widget value already set
|
||||
// Add the node to the graph
|
||||
targetGraph.add(node)
|
||||
|
||||
return { success: true, value: node }
|
||||
|
||||
@@ -23,7 +23,11 @@ const EXPECTED_DEFAULT_TYPES = [
|
||||
'audio_encoders',
|
||||
'model_patches',
|
||||
'animatediff_models',
|
||||
'animatediff_motion_lora'
|
||||
'animatediff_motion_lora',
|
||||
'chatterbox/chatterbox',
|
||||
'chatterbox/chatterbox_turbo',
|
||||
'chatterbox/chatterbox_multilingual',
|
||||
'chatterbox/chatterbox_vc'
|
||||
] as const
|
||||
|
||||
type NodeDefStoreType = ReturnType<typeof useNodeDefStore>
|
||||
@@ -61,7 +65,11 @@ const MOCK_NODE_NAMES = [
|
||||
'AudioEncoderLoader',
|
||||
'ModelPatchLoader',
|
||||
'ADE_LoadAnimateDiffModel',
|
||||
'ADE_AnimateDiffLoRALoader'
|
||||
'ADE_AnimateDiffLoRALoader',
|
||||
'FL_ChatterboxTTS',
|
||||
'FL_ChatterboxTurboTTS',
|
||||
'FL_ChatterboxMultilingualTTS',
|
||||
'FL_ChatterboxVC'
|
||||
] as const
|
||||
|
||||
const mockNodeDefsByName = Object.fromEntries(
|
||||
@@ -135,6 +143,36 @@ describe('useModelToNodeStore', () => {
|
||||
const provider = modelToNodeStore.getNodeProvider('checkpoints')
|
||||
expect(provider).toBeDefined()
|
||||
})
|
||||
|
||||
it('should fallback to top-level folder for hierarchical model types', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
|
||||
const provider = modelToNodeStore.getNodeProvider('checkpoints/subfolder')
|
||||
expect(provider).toBeDefined()
|
||||
expect(provider?.nodeDef?.name).toBe('CheckpointLoaderSimple')
|
||||
})
|
||||
|
||||
it('should return undefined for hierarchical type with unregistered top-level', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
|
||||
expect(
|
||||
modelToNodeStore.getNodeProvider('UnknownType/subfolder')
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return provider for chatterbox nodes with empty key', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
|
||||
const provider = modelToNodeStore.getNodeProvider(
|
||||
'chatterbox/chatterbox_vc'
|
||||
)
|
||||
expect(provider).toBeDefined()
|
||||
expect(provider?.nodeDef?.name).toBe('FL_ChatterboxVC')
|
||||
expect(provider?.key).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getAllNodeProviders', () => {
|
||||
@@ -184,6 +222,17 @@ describe('useModelToNodeStore', () => {
|
||||
const providers = modelToNodeStore.getAllNodeProviders('checkpoints')
|
||||
expect(providers.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should fallback to top-level folder for hierarchical model types', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
|
||||
const providers = modelToNodeStore.getAllNodeProviders(
|
||||
'checkpoints/subfolder'
|
||||
)
|
||||
expect(providers).toHaveLength(2)
|
||||
expect(providers[0].nodeDef.name).toBe('CheckpointLoaderSimple')
|
||||
})
|
||||
})
|
||||
|
||||
describe('registerNodeProvider', () => {
|
||||
@@ -491,5 +540,16 @@ describe('useModelToNodeStore', () => {
|
||||
expect(modelToNodeStore.getNodeProvider('')).toBeUndefined()
|
||||
expect(modelToNodeStore.getAllNodeProviders('')).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle invalid input types gracefully', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
|
||||
expect(modelToNodeStore.getNodeProvider(null as any)).toBeUndefined()
|
||||
expect(modelToNodeStore.getNodeProvider(undefined as any)).toBeUndefined()
|
||||
expect(modelToNodeStore.getNodeProvider(123 as any)).toBeUndefined()
|
||||
expect(modelToNodeStore.getAllNodeProviders(null as any)).toEqual([])
|
||||
expect(modelToNodeStore.getAllNodeProviders(undefined as any)).toEqual([])
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -9,7 +9,7 @@ export class ModelNodeProvider {
|
||||
/** The node definition to use for this model. */
|
||||
public nodeDef: ComfyNodeDefImpl
|
||||
|
||||
/** The node input key for where to inside the model name. */
|
||||
/** The node input key for where to insert the model name. */
|
||||
public key: string
|
||||
|
||||
constructor(nodeDef: ComfyNodeDefImpl, key: string) {
|
||||
@@ -73,23 +73,51 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
|
||||
return nodeTypeToCategory.value[nodeType]
|
||||
}
|
||||
|
||||
/**
|
||||
* Find providers for modelType with hierarchical fallback.
|
||||
* Tries exact match first, then falls back to top-level segment (e.g., "parent/child" → "parent").
|
||||
* Note: Only falls back one level; "a/b/c" tries "a/b/c" then "a", not "a/b".
|
||||
*/
|
||||
function findProvidersWithFallback(
|
||||
modelType: string
|
||||
): ModelNodeProvider[] | undefined {
|
||||
if (!modelType || typeof modelType !== 'string') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const exactMatch = modelToNodeMap.value[modelType]
|
||||
if (exactMatch && exactMatch.length > 0) return exactMatch
|
||||
|
||||
const topLevel = modelType.split('/')[0]
|
||||
if (topLevel === modelType) return undefined
|
||||
|
||||
const fallback = modelToNodeMap.value[topLevel]
|
||||
|
||||
if (fallback && fallback.length > 0) return fallback
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the node provider for the given model type name.
|
||||
* Supports hierarchical lookups: if "parent/child" has no match, falls back to "parent".
|
||||
* @param modelType The name of the model type to get the node provider for.
|
||||
* @returns The node provider for the given model type name.
|
||||
*/
|
||||
function getNodeProvider(modelType: string): ModelNodeProvider | undefined {
|
||||
registerDefaults()
|
||||
return modelToNodeMap.value[modelType]?.[0]
|
||||
return findProvidersWithFallback(modelType)?.[0]
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the list of all valid node providers for the given model type name.
|
||||
* Supports hierarchical lookups: if "parent/child" has no match, falls back to "parent".
|
||||
* @param modelType The name of the model type to get the node providers for.
|
||||
* @returns The list of all valid node providers for the given model type name.
|
||||
*/
|
||||
function getAllNodeProviders(modelType: string): ModelNodeProvider[] {
|
||||
registerDefaults()
|
||||
return modelToNodeMap.value[modelType] ?? []
|
||||
return findProvidersWithFallback(modelType) ?? []
|
||||
}
|
||||
/**
|
||||
* Register a node provider for the given model type name.
|
||||
@@ -153,6 +181,17 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
|
||||
'ADE_AnimateDiffLoRALoader',
|
||||
'name'
|
||||
)
|
||||
|
||||
// Chatterbox TTS nodes: empty key means the node auto-loads models without
|
||||
// a widget selector (createModelNodeFromAsset skips widget assignment)
|
||||
quickRegister('chatterbox/chatterbox', 'FL_ChatterboxTTS', '')
|
||||
quickRegister('chatterbox/chatterbox_turbo', 'FL_ChatterboxTurboTTS', '')
|
||||
quickRegister(
|
||||
'chatterbox/chatterbox_multilingual',
|
||||
'FL_ChatterboxMultilingualTTS',
|
||||
''
|
||||
)
|
||||
quickRegister('chatterbox/chatterbox_vc', 'FL_ChatterboxVC', '')
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user