diff --git a/src/platform/assets/schemas/assetSchema.ts b/src/platform/assets/schemas/assetSchema.ts index 2c051a30d..2238d8aa4 100644 --- a/src/platform/assets/schemas/assetSchema.ts +++ b/src/platform/assets/schemas/assetSchema.ts @@ -4,16 +4,16 @@ import { z } from 'zod' const zAsset = z.object({ id: z.string(), name: z.string(), - asset_hash: z.string().nullable(), + asset_hash: z.string().optional(), size: z.number(), - mime_type: z.string().nullable(), - tags: z.array(z.string()), + mime_type: z.string().optional(), + tags: z.array(z.string()).optional().default([]), + preview_id: z.string().nullable().optional(), preview_url: z.string().optional(), created_at: z.string(), updated_at: z.string().optional(), - last_access_time: z.string(), - user_metadata: z.record(z.unknown()).optional(), // API allows arbitrary key-value pairs - preview_id: z.string().nullable().optional() + last_access_time: z.string().optional(), + user_metadata: z.record(z.unknown()).optional() // API allows arbitrary key-value pairs }) const zAssetResponse = z.object({ diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index 4f7925294..436dac26d 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -29,6 +29,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { return new Set( Object.values(modelToNodeMap.value) .flat() + .filter((provider) => !!provider.nodeDef) .map((provider) => provider.nodeDef.name) ) }) @@ -38,6 +39,8 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { const lookup: Record = {} for (const [category, providers] of Object.entries(modelToNodeMap.value)) { for (const provider of providers) { + // Extension nodes may not be installed + if (!provider.nodeDef) continue // Only store the first category for each node type (matches current assetService behavior) if (!lookup[provider.nodeDef.name]) { lookup[provider.nodeDef.name] = category @@ -98,6 +101,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { nodeProvider: ModelNodeProvider ) { registerDefaults() + if (!nodeProvider.nodeDef) return if (!modelToNodeMap.value[modelType]) { modelToNodeMap.value[modelType] = [] } @@ -131,10 +135,24 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name') quickRegister('vae', 'VAELoader', 'vae_name') quickRegister('controlnet', 'ControlNetLoader', 'control_net_name') - quickRegister('unet', 'UNETLoader', 'unet_name') + quickRegister('diffusion_models', 'UNETLoader', 'unet_name') quickRegister('upscale_models', 'UpscaleModelLoader', 'model_name') - quickRegister('style_models', 'StyleModelLoader', 'style_model') + quickRegister('style_models', 'StyleModelLoader', 'style_model_name') quickRegister('gligen', 'GLIGENLoader', 'gligen_name') + quickRegister('clip_vision', 'CLIPVisionLoader', 'clip_name') + quickRegister('text_encoders', 'CLIPLoader', 'clip_name') + quickRegister('audio_encoders', 'AudioEncoderLoader', 'audio_encoder_name') + quickRegister('model_patches', 'ModelPatchLoader', 'name') + quickRegister( + 'animatediff_models', + 'ADE_LoadAnimateDiffModel', + 'model_name' + ) + quickRegister( + 'animatediff_motion_lora', + 'ADE_AnimateDiffLoRALoader', + 'name' + ) } return { diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts index b07c34a41..c04f1764d 100644 --- a/tests-ui/tests/store/modelToNodeStore.test.ts +++ b/tests-ui/tests/store/modelToNodeStore.test.ts @@ -13,10 +13,16 @@ const EXPECTED_DEFAULT_TYPES = [ 'loras', 'vae', 'controlnet', - 'unet', + 'diffusion_models', 'upscale_models', 'style_models', - 'gligen' + 'gligen', + 'clip_vision', + 'text_encoders', + 'audio_encoders', + 'model_patches', + 'animatediff_models', + 'animatediff_motion_lora' ] as const type NodeDefStoreType = ReturnType @@ -48,7 +54,13 @@ const MOCK_NODE_NAMES = [ 'UNETLoader', 'UpscaleModelLoader', 'StyleModelLoader', - 'GLIGENLoader' + 'GLIGENLoader', + 'CLIPVisionLoader', + 'CLIPLoader', + 'AudioEncoderLoader', + 'ModelPatchLoader', + 'ADE_LoadAnimateDiffModel', + 'ADE_AnimateDiffLoRALoader' ] as const const mockNodeDefsByName = Object.fromEntries( @@ -84,7 +96,7 @@ describe('useModelToNodeStore', () => { const modelToNodeStore = useModelToNodeStore() modelToNodeStore.registerDefaults() expect(Object.keys(modelToNodeStore.modelToNodeMap)).toEqual( - expect.arrayContaining(['checkpoints', 'unet']) + expect.arrayContaining(['checkpoints', 'diffusion_models']) ) }) }) @@ -153,9 +165,10 @@ describe('useModelToNodeStore', () => { const modelToNodeStore = useModelToNodeStore() modelToNodeStore.registerDefaults() - const unetProviders = modelToNodeStore.getAllNodeProviders('unet') - expect(unetProviders).toHaveLength(1) - expect(unetProviders[0].nodeDef.name).toBe('UNETLoader') + const diffusionModelProviders = + modelToNodeStore.getAllNodeProviders('diffusion_models') + expect(diffusionModelProviders).toHaveLength(1) + expect(diffusionModelProviders[0].nodeDef.name).toBe('UNETLoader') }) it('should return empty array for unregistered model type', () => { @@ -173,6 +186,22 @@ describe('useModelToNodeStore', () => { }) describe('registerNodeProvider', () => { + it('should not register provider when nodeDef is undefined', () => { + const modelToNodeStore = useModelToNodeStore() + const providerWithoutNodeDef = new ModelNodeProvider( + undefined as any, + 'custom_key' + ) + + modelToNodeStore.registerNodeProvider( + 'custom_type', + providerWithoutNodeDef + ) + + const retrieved = modelToNodeStore.getNodeProvider('custom_type') + expect(retrieved).toBeUndefined() + }) + it('should register provider directly', () => { const modelToNodeStore = useModelToNodeStore() const nodeDefStore = useNodeDefStore() @@ -250,8 +279,20 @@ describe('useModelToNodeStore', () => { }).not.toThrow() const provider = modelToNodeStore.getNodeProvider('test_type') - // Optional chaining needed since getNodeProvider() can return undefined expect(provider?.nodeDef).toBeUndefined() + + expect(() => modelToNodeStore.getRegisteredNodeTypes()).not.toThrow() + expect(() => + modelToNodeStore.getCategoryForNodeType('NonExistentLoader') + ).not.toThrow() + + // Non-existent nodes are filtered out from registered types + const types = modelToNodeStore.getRegisteredNodeTypes() + expect(types.has('NonExistentLoader')).toBe(false) + + expect( + modelToNodeStore.getCategoryForNodeType('NonExistentLoader') + ).toBeUndefined() }) it('should allow multiple node classes for same model type', () => {