diff --git a/src/constants/coreSettings.ts b/src/constants/coreSettings.ts index 3d7859f3b..068d102ad 100644 --- a/src/constants/coreSettings.ts +++ b/src/constants/coreSettings.ts @@ -977,8 +977,7 @@ export const CORE_SETTINGS: SettingParams[] = [ id: 'Comfy.Assets.UseAssetAPI', name: 'Use Asset API for model library', type: 'boolean', - tooltip: - 'Use new asset API instead of experiment endpoints for model browsing', + tooltip: 'Use new Asset API for model browsing', defaultValue: false, experimental: true } diff --git a/src/schemas/assetSchema.ts b/src/schemas/assetSchema.ts new file mode 100644 index 000000000..277efcbb0 --- /dev/null +++ b/src/schemas/assetSchema.ts @@ -0,0 +1,39 @@ +import { z } from 'zod' + +// Zod schemas for asset API validation +const zAsset = z.object({ + id: z.string(), + name: z.string(), + tags: z.array(z.string()), + size: z.number(), + created_at: z.string().optional() +}) + +const zAssetResponse = z.object({ + assets: z.array(zAsset).optional(), + total: z.number().optional(), + has_more: z.boolean().optional() +}) + +const zModelFolder = z.object({ + name: z.string(), + folders: z.array(z.string()) +}) + +// Export schemas following repository patterns +export const assetResponseSchema = zAssetResponse + +// Export types derived from Zod schemas +export type AssetResponse = z.infer +export type ModelFolder = z.infer + +// Common interfaces for API responses +export interface ModelFile { + name: string + pathIndex: number +} + +export interface ModelFolderInfo { + name: string + folders: string[] +} diff --git a/src/scripts/api.ts b/src/scripts/api.ts index 9a6a0b7d3..9480face6 100644 --- a/src/scripts/api.ts +++ b/src/scripts/api.ts @@ -30,6 +30,7 @@ import type { User, UserDataFullInfo } from '@/schemas/apiSchema' +import type { ModelFile, ModelFolderInfo } from '@/schemas/assetSchema' import type { ComfyApiWorkflow, ComfyWorkflowJSON, @@ -675,15 +676,14 @@ export class ComfyApi extends EventTarget { * Gets a list of model folder keys (eg ['checkpoints', 'loras', ...]) * @returns The list of model folder keys */ - async getModelFolders(): Promise<{ name: string; folders: string[] }[]> { + async getModelFolders(): Promise { const res = await this.fetchApi(`/experiment/models`) if (res.status === 404) { return [] } const folderBlacklist = ['configs', 'custom_nodes'] return (await res.json()).filter( - (folder: { name: string; folders: string[] }) => - !folderBlacklist.includes(folder.name) + (folder: ModelFolderInfo) => !folderBlacklist.includes(folder.name) ) } @@ -692,9 +692,7 @@ export class ComfyApi extends EventTarget { * @param {string} folder The folder to list models from, such as 'checkpoints' * @returns The list of model filenames within the specified folder */ - async getModels( - folder: string - ): Promise<{ name: string; pathIndex: number }[]> { + async getModels(folder: string): Promise { const res = await this.fetchApi(`/experiment/models/${folder}`) if (res.status === 404) { return [] diff --git a/src/services/assetService.ts b/src/services/assetService.ts index 90a4d8320..144beb231 100644 --- a/src/services/assetService.ts +++ b/src/services/assetService.ts @@ -1,63 +1,26 @@ +import { fromZodError } from 'zod-validation-error' + +import { + type AssetResponse, + type ModelFile, + type ModelFolder, + assetResponseSchema +} from '@/schemas/assetSchema' import { api } from '@/scripts/api' const ASSETS_ENDPOINT = '/assets' const MODELS_TAG = 'models' const MISSING_TAG = 'missing' -// Types for asset API responses -interface AssetResponse { - assets?: Asset[] - total?: number - has_more?: boolean -} - -interface Asset { - id: string - name: string - tags: string[] - size: number - created_at?: string -} - /** - * Type guard for validating asset structure + * Validates asset response data using Zod schema */ -function isValidAsset(asset: unknown): asset is Asset { - return ( - asset !== null && - typeof asset === 'object' && - 'id' in asset && - 'name' in asset && - 'tags' in asset && - Array.isArray((asset as Asset).tags) - ) -} +function validateAssetResponse(data: unknown): AssetResponse { + const result = assetResponseSchema.safeParse(data) + if (result.success) return result.data -/** - * Creates predicate for filtering assets by folder and excluding missing ones - */ -function createAssetFolderFilter(folder?: string) { - return (asset: unknown): asset is Asset => { - if (!isValidAsset(asset) || asset.tags.includes(MISSING_TAG)) { - return false - } - if (folder && !asset.tags.includes(folder)) { - return false - } - return true - } -} - -/** - * Creates predicate for filtering folder assets (requires name) - */ -function createFolderAssetFilter(folder: string) { - return (asset: unknown): asset is Asset => { - if (!isValidAsset(asset) || !asset.name) { - return false - } - return asset.tags.includes(folder) && !asset.tags.includes(MISSING_TAG) - } + const error = fromZodError(result.error) + throw new Error(`Invalid asset response against zod schema:\n${error}`) } /** @@ -66,7 +29,7 @@ function createFolderAssetFilter(folder: string) { */ function createAssetService() { /** - * Handles API response with consistent error handling + * Handles API response with consistent error handling and Zod validation */ async function handleAssetRequest( url: string, @@ -78,7 +41,8 @@ function createAssetService() { `Unable to load ${context}: Server returned ${res.status}. Please try again.` ) } - return await res.json() + const data = await res.json() + return validateAssetResponse(data) } /** * Gets a list of model folder keys from the asset API @@ -90,9 +54,7 @@ function createAssetService() { * * @returns The list of model folder keys */ - async function getAssetModelFolders(): Promise< - { name: string; folders: string[] }[] - > { + async function getAssetModelFolders(): Promise { const data = await handleAssetRequest( `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG}`, 'model folders' @@ -102,22 +64,17 @@ function createAssetService() { const blacklistedDirectories = ['configs'] // Extract directory names from assets that actually exist, exclude missing assets - const discoveredFolders = new Set() - if (data?.assets) { - const directoryTags = data.assets - .filter(createAssetFolderFilter()) - .flatMap((asset) => asset.tags) - .filter( + const discoveredFolders = new Set( + data?.assets + ?.filter((asset) => !asset.tags.includes(MISSING_TAG)) + ?.flatMap((asset) => asset.tags) + ?.filter( (tag) => tag !== MODELS_TAG && !blacklistedDirectories.includes(tag) - ) - - for (const tag of directoryTags) { - discoveredFolders.add(tag) - } - } + ) ?? [] + ) // Return only discovered folders in alphabetical order - const sortedFolders = Array.from(discoveredFolders).sort() + const sortedFolders = Array.from(discoveredFolders).toSorted() return sortedFolders.map((name) => ({ name, folders: [] })) } @@ -126,20 +83,23 @@ function createAssetService() { * @param folder The folder to list models from, such as 'checkpoints' * @returns The list of model filenames within the specified folder */ - async function getAssetModels( - folder: string - ): Promise<{ name: string; pathIndex: number }[]> { + async function getAssetModels(folder: string): Promise { const data = await handleAssetRequest( `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${folder}`, `models for ${folder}` ) - return data?.assets - ? data.assets.filter(createFolderAssetFilter(folder)).map((asset) => ({ + return ( + data?.assets + ?.filter( + (asset) => + !asset.tags.includes(MISSING_TAG) && asset.tags.includes(folder) + ) + ?.map((asset) => ({ name: asset.name, pathIndex: 0 - })) - : [] + })) ?? [] + ) } return { diff --git a/src/stores/modelStore.ts b/src/stores/modelStore.ts index f8dc5441d..34856c155 100644 --- a/src/stores/modelStore.ts +++ b/src/stores/modelStore.ts @@ -1,6 +1,7 @@ import { defineStore } from 'pinia' import { computed, ref } from 'vue' +import type { ModelFile } from '@/schemas/assetSchema' import { api } from '@/scripts/api' import { assetService } from '@/services/assetService' import { useSettingStore } from '@/stores/settingStore' @@ -157,9 +158,7 @@ export class ModelFolder { constructor( public directory: string, - private getModelsFunc: ( - folder: string - ) => Promise<{ name: string; pathIndex: number }[]> + private getModelsFunc: (folder: string) => Promise ) {} get key(): string { diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts index d7c4a673b..29ffee3b0 100644 --- a/tests-ui/tests/services/assetService.test.ts +++ b/tests-ui/tests/services/assetService.test.ts @@ -83,19 +83,20 @@ describe('assetService', () => { expect(folderNames).not.toContain('configs') }) - it('should handle errors and empty responses', async () => { - // Empty response + it('should handle empty responses', async () => { mockApiResponse([]) const emptyResult = await assetService.getAssetModelFolders() expect(emptyResult).toHaveLength(0) + }) - // Network error + it('should handle network errors', async () => { vi.mocked(api.fetchApi).mockRejectedValueOnce(new Error('Network error')) await expect(assetService.getAssetModelFolders()).rejects.toThrow( 'Network error' ) + }) - // HTTP error + it('should handle HTTP errors', async () => { mockApiError(500) await expect(assetService.getAssetModelFolders()).rejects.toThrow( 'Unable to load model folders: Server returned 500. Please try again.' @@ -107,7 +108,6 @@ describe('assetService', () => { it('should return filtered models for folder', async () => { const assets = [ { ...MOCK_ASSETS.checkpoints, name: 'valid.safetensors' }, - { ...MOCK_ASSETS.checkpoints, name: undefined }, // Invalid name { ...MOCK_ASSETS.loras, name: 'lora.safetensors' }, // Wrong tag { id: 'uuid-4', diff --git a/tests-ui/tests/store/modelStore.test.ts b/tests-ui/tests/store/modelStore.test.ts index 26fa8be65..e77e9be0d 100644 --- a/tests-ui/tests/store/modelStore.test.ts +++ b/tests-ui/tests/store/modelStore.test.ts @@ -38,7 +38,9 @@ function enableMocks(useAssetAPI = false) { return false }) } - vi.mocked(useSettingStore).mockReturnValue(mockSettingStore as any) + vi.mocked(useSettingStore, { partial: true }).mockReturnValue( + mockSettingStore + ) // Mock experimental API - returns objects with name and folders properties vi.mocked(api.getModels).mockResolvedValue([ diff --git a/tsconfig.json b/tsconfig.json index 675a099c3..e183ed3f2 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,9 +1,9 @@ { "compilerOptions": { - "target": "ES2022", + "target": "ES2023", "useDefineForClassFields": true, "module": "ESNext", - "lib": ["ES2022", "DOM", "DOM.Iterable"], + "lib": ["ES2023", "ES2023.Array", "DOM", "DOM.Iterable"], "skipLibCheck": true, "incremental": true, "sourceMap": true,