diff --git a/package.json b/package.json index 770ef7e04..a1089933f 100644 --- a/package.json +++ b/package.json @@ -27,6 +27,8 @@ "preview": "nx preview", "lint": "eslint src --cache", "lint:fix": "eslint src --cache --fix", + "lint:unstaged": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache", + "lint:unstaged:fix": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache --fix", "lint:no-cache": "eslint src", "lint:fix:no-cache": "eslint src --fix", "knip": "knip --cache", diff --git a/src/i18n.ts b/src/i18n.ts index 102ac2600..38a8dfe95 100644 --- a/src/i18n.ts +++ b/src/i18n.ts @@ -76,7 +76,7 @@ export const i18n = createI18n({ }) /** Convenience shorthand: i18n.global */ -export const { t, te } = i18n.global +export const { t, te, d } = i18n.global /** * Safe translation function that returns the fallback message if the key is not found. diff --git a/src/lib/litegraph/src/litegraph.ts b/src/lib/litegraph/src/litegraph.ts index 46202a219..46b094af0 100644 --- a/src/lib/litegraph/src/litegraph.ts +++ b/src/lib/litegraph/src/litegraph.ts @@ -140,7 +140,7 @@ export { BaseWidget } from './widgets/BaseWidget' export { LegacyWidget } from './widgets/LegacyWidget' -export { isComboWidget } from './widgets/widgetMap' +export { isComboWidget, isAssetWidget } from './widgets/widgetMap' // Additional test-specific exports export { LGraphButton } from './LGraphButton' export { MovingOutputLink } from './canvas/MovingOutputLink' diff --git a/src/lib/litegraph/src/widgets/AssetWidget.ts b/src/lib/litegraph/src/widgets/AssetWidget.ts index f8a8e1209..1a5047beb 100644 --- a/src/lib/litegraph/src/widgets/AssetWidget.ts +++ b/src/lib/litegraph/src/widgets/AssetWidget.ts @@ -13,6 +13,22 @@ export class AssetWidget this.value = widget.value?.toString() ?? '' } + override set value(value: IAssetWidget['value']) { + const oldValue = this.value + super.value = value + + // Force canvas redraw when value changes to show update immediately + if (oldValue !== value && this.node.graph?.list_of_graphcanvas) { + for (const canvas of this.node.graph.list_of_graphcanvas) { + canvas.setDirty(true) + } + } + } + + override get value(): IAssetWidget['value'] { + return super.value + } + override get _displayValue(): string { return String(this.value) //FIXME: Resolve asset name } diff --git a/src/lib/litegraph/src/widgets/widgetMap.ts b/src/lib/litegraph/src/widgets/widgetMap.ts index 02cdb5597..0e6a34fe5 100644 --- a/src/lib/litegraph/src/widgets/widgetMap.ts +++ b/src/lib/litegraph/src/widgets/widgetMap.ts @@ -1,5 +1,6 @@ import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode' import type { + IAssetWidget, IBaseWidget, IComboWidget, IWidget, @@ -132,4 +133,9 @@ export function isComboWidget(widget: IBaseWidget): widget is IComboWidget { return widget.type === 'combo' } +/** Type guard: Narrow **from {@link IBaseWidget}** to {@link IAssetWidget}. */ +export function isAssetWidget(widget: IBaseWidget): widget is IAssetWidget { + return widget.type === 'asset' +} + // #endregion Type Guards diff --git a/src/locales/en/main.json b/src/locales/en/main.json index 3abd85335..74f0352ae 100644 --- a/src/locales/en/main.json +++ b/src/locales/en/main.json @@ -1873,6 +1873,13 @@ "noModelsInFolder": "No {type} available in this folder", "searchAssetsPlaceholder": "Search assets...", "allModels": "All Models", - "unknown": "Unknown" + "unknown": "Unknown", + "fileFormats": "File formats", + "baseModels": "Base models", + "sortBy": "Sort by", + "sortAZ": "A-Z", + "sortZA": "Z-A", + "sortRecent": "Recent", + "sortPopular": "Popular" } } diff --git a/src/platform/assets/components/AssetBrowserModal.stories.ts b/src/platform/assets/components/AssetBrowserModal.stories.ts index acc93181d..9d2321d57 100644 --- a/src/platform/assets/components/AssetBrowserModal.stories.ts +++ b/src/platform/assets/components/AssetBrowserModal.stories.ts @@ -1,6 +1,7 @@ import type { Meta, StoryObj } from '@storybook/vue3-vite' import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue' +import type { AssetDisplayItem } from '@/platform/assets/composables/useAssetBrowser' import { createMockAssets, mockAssets @@ -56,7 +57,7 @@ export const Default: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { @@ -96,7 +97,7 @@ export const SingleAssetType: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { @@ -145,7 +146,7 @@ export const NoLeftPanel: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { diff --git a/src/platform/assets/components/AssetBrowserModal.vue b/src/platform/assets/components/AssetBrowserModal.vue index de05f437d..cb45f38ba 100644 --- a/src/platform/assets/components/AssetBrowserModal.vue +++ b/src/platform/assets/components/AssetBrowserModal.vue @@ -12,7 +12,7 @@ :nav-items="availableCategories" > @@ -37,7 +37,7 @@ diff --git a/src/platform/assets/components/AssetCard.vue b/src/platform/assets/components/AssetCard.vue index e379099c1..be7c45ca5 100644 --- a/src/platform/assets/components/AssetCard.vue +++ b/src/platform/assets/components/AssetCard.vue @@ -14,7 +14,7 @@ 'bg-ivory-100 border border-gray-300 dark-theme:bg-charcoal-400 dark-theme:border-charcoal-600', 'hover:transform hover:-translate-y-0.5 hover:shadow-lg hover:shadow-black/10 hover:border-gray-400', 'dark-theme:hover:shadow-lg dark-theme:hover:shadow-black/30 dark-theme:hover:border-charcoal-700', - 'focus:outline-none focus:ring-2 focus:ring-blue-500 dark-theme:focus:ring-blue-400' + 'focus:outline-none focus:transform focus:-translate-y-0.5 focus:shadow-lg focus:shadow-black/10 dark-theme:focus:shadow-black/30' ], // Div-specific styles !interactive && [ diff --git a/src/platform/assets/components/AssetFilterBar.vue b/src/platform/assets/components/AssetFilterBar.vue index 1f3295b43..904ce3e82 100644 --- a/src/platform/assets/components/AssetFilterBar.vue +++ b/src/platform/assets/components/AssetFilterBar.vue @@ -3,7 +3,7 @@
void + ): Promise { if (import.meta.env.DEV) { - console.log('Asset selected:', asset.id, asset.name) + console.debug('Asset selected:', assetId) + } + + if (!onSelect) { + return + } + + try { + const detailAsset = await assetService.getAssetDetails(assetId) + const filename = detailAsset.user_metadata?.filename + const validatedFilename = assetFilenameSchema.safeParse(filename) + if (!validatedFilename.success) { + console.error( + 'Invalid asset filename:', + validatedFilename.error.errors, + 'for asset:', + assetId + ) + return + } + + onSelect(validatedFilename.data) + } catch (error) { + console.error(`Failed to fetch asset details for ${assetId}:`, error) } - return asset.id } return { @@ -182,7 +212,6 @@ export function useAssetBrowser(assets: AssetItem[] = []) { filteredAssets, // Actions - selectAsset, - transformAssetForDisplay + selectAssetWithCallback } } diff --git a/src/platform/assets/composables/useAssetBrowserDialog.ts b/src/platform/assets/composables/useAssetBrowserDialog.ts index e5f63eead..31f75c353 100644 --- a/src/platform/assets/composables/useAssetBrowserDialog.ts +++ b/src/platform/assets/composables/useAssetBrowserDialog.ts @@ -1,5 +1,7 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue' -import { useDialogStore } from '@/stores/dialogStore' +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { assetService } from '@/platform/assets/services/assetService' +import { type DialogComponentProps, useDialogStore } from '@/stores/dialogStore' interface AssetBrowserDialogProps { /** ComfyUI node type for context (e.g., 'CheckpointLoaderSimple') */ @@ -8,36 +10,29 @@ interface AssetBrowserDialogProps { inputName: string /** Current selected asset value */ currentValue?: string - /** Callback for when an asset is selected */ - onAssetSelected?: (assetPath: string) => void + /** + * Callback for when an asset is selected + * @param {string} filename - The validated filename from user_metadata.filename + */ + onAssetSelected?: (filename: string) => void } export const useAssetBrowserDialog = () => { const dialogStore = useDialogStore() const dialogKey = 'global-asset-browser' - function hide() { - dialogStore.closeDialog({ key: dialogKey }) - } - - function show(props: AssetBrowserDialogProps) { - const handleAssetSelected = (assetPath: string) => { - props.onAssetSelected?.(assetPath) - hide() // Auto-close on selection + async function show(props: AssetBrowserDialogProps) { + const handleAssetSelected = (filename: string) => { + props.onAssetSelected?.(filename) + dialogStore.closeDialog({ key: dialogKey }) } - - const handleClose = () => { - hide() - } - - // Default dialog configuration for AssetBrowserModal - const dialogComponentProps = { + const dialogComponentProps: DialogComponentProps = { headless: true, modal: true, - closable: false, + closable: true, pt: { root: { - class: 'rounded-2xl overflow-hidden' + class: 'rounded-2xl overflow-hidden asset-browser-dialog' }, header: { class: 'p-0 hidden' @@ -48,6 +43,17 @@ export const useAssetBrowserDialog = () => { } } + const assets: AssetItem[] = await assetService + .getAssetsForNodeType(props.nodeType) + .catch((error) => { + console.error( + 'Failed to fetch assets for node type:', + props.nodeType, + error + ) + return [] + }) + dialogStore.showDialog({ key: dialogKey, component: AssetBrowserModal, @@ -55,12 +61,13 @@ export const useAssetBrowserDialog = () => { nodeType: props.nodeType, inputName: props.inputName, currentValue: props.currentValue, + assets, onSelect: handleAssetSelected, - onClose: handleClose + onClose: () => dialogStore.closeDialog({ key: dialogKey }) }, dialogComponentProps }) } - return { show, hide } + return { show } } diff --git a/src/platform/assets/schemas/assetSchema.ts b/src/platform/assets/schemas/assetSchema.ts index fab41649a..2c051a30d 100644 --- a/src/platform/assets/schemas/assetSchema.ts +++ b/src/platform/assets/schemas/assetSchema.ts @@ -4,13 +4,13 @@ import { z } from 'zod' const zAsset = z.object({ id: z.string(), name: z.string(), - asset_hash: z.string(), + asset_hash: z.string().nullable(), size: z.number(), - mime_type: z.string(), + mime_type: z.string().nullable(), tags: z.array(z.string()), preview_url: z.string().optional(), created_at: z.string(), - updated_at: z.string(), + updated_at: z.string().optional(), last_access_time: z.string(), user_metadata: z.record(z.unknown()).optional(), // API allows arbitrary key-value pairs preview_id: z.string().nullable().optional() @@ -33,6 +33,14 @@ const zModelFile = z.object({ pathIndex: z.number() }) +// Filename validation schema +export const assetFilenameSchema = z + .string() + .min(1, 'Filename cannot be empty') + .regex(/^[^\\:*?"<>|]+$/, 'Invalid filename characters') // Allow forward slashes, block backslashes and other unsafe chars + .regex(/^(?!\/|.*\.\.)/, 'Path must not start with / or contain ..') // Prevent absolute paths and directory traversal + .trim() + // Export schemas following repository patterns export const assetResponseSchema = zAssetResponse diff --git a/src/platform/assets/services/assetService.ts b/src/platform/assets/services/assetService.ts index 74b20a753..7d0f82cbb 100644 --- a/src/platform/assets/services/assetService.ts +++ b/src/platform/assets/services/assetService.ts @@ -1,6 +1,7 @@ import { fromZodError } from 'zod-validation-error' import { + type AssetItem, type AssetResponse, type ModelFile, type ModelFolder, @@ -127,10 +128,75 @@ function createAssetService() { ) } + /** + * Gets assets for a specific node type by finding the matching category + * and fetching all assets with that category tag + * + * @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple') + * @returns Promise - Full asset objects with preserved metadata + */ + async function getAssetsForNodeType(nodeType: string): Promise { + if (!nodeType || typeof nodeType !== 'string') { + return [] + } + + // Find the category for this node type using efficient O(1) lookup + const modelToNodeStore = useModelToNodeStore() + const category = modelToNodeStore.getCategoryForNodeType(nodeType) + + if (!category) { + return [] + } + + // Fetch assets for this category using same API pattern as getAssetModels + const data = await handleAssetRequest( + `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}`, + `assets for ${nodeType}` + ) + + // Return full AssetItem[] objects (don't strip like getAssetModels does) + return ( + data?.assets?.filter( + (asset) => + !asset.tags.includes(MISSING_TAG) && asset.tags.includes(category) + ) ?? [] + ) + } + + /** + * Gets complete details for a specific asset by ID + * Calls the detail endpoint which includes user_metadata and all fields + * + * @param id - The asset ID + * @returns Promise - Complete asset object with user_metadata + */ + async function getAssetDetails(id: string): Promise { + const res = await api.fetchApi(`${ASSETS_ENDPOINT}/${id}`) + if (!res.ok) { + throw new Error( + `Unable to load asset details for ${id}: Server returned ${res.status}. Please try again.` + ) + } + const data = await res.json() + + // Validate the single asset response against our schema + const result = assetResponseSchema.safeParse({ assets: [data] }) + if (result.success && result.data.assets?.[0]) { + return result.data.assets[0] + } + + const error = result.error + ? fromZodError(result.error) + : 'Unknown validation error' + throw new Error(`Invalid asset response against zod schema:\n${error}`) + } + return { getAssetModelFolders, getAssetModels, - isAssetBrowserEligible + isAssetBrowserEligible, + getAssetsForNodeType, + getAssetDetails } } diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts index 2387fc59c..59705458c 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts @@ -3,10 +3,9 @@ import { ref } from 'vue' import MultiSelectWidget from '@/components/graph/widgets/MultiSelectWidget.vue' import { t } from '@/i18n' import type { LGraphNode } from '@/lib/litegraph/src/litegraph' -import type { - IBaseWidget, - IComboWidget -} from '@/lib/litegraph/src/types/widgets' +import { isAssetWidget, isComboWidget } from '@/lib/litegraph/src/litegraph' +import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets' +import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog' import { assetService } from '@/platform/assets/services/assetService' import { useSettingStore } from '@/platform/settings/settingStore' import { transformInputSpecV2ToV1 } from '@/schemas/nodeDef/migration' @@ -73,11 +72,29 @@ const addComboWidget = ( const currentValue = getDefaultValue(inputSpec) const displayLabel = currentValue ?? t('widgets.selectModel') - const widget = node.addWidget('asset', inputSpec.name, displayLabel, () => { - console.log( - `Asset Browser would open here for:\nNode: ${node.type}\nWidget: ${inputSpec.name}\nCurrent Value:${currentValue}` - ) - }) + const assetBrowserDialog = useAssetBrowserDialog() + + const widget = node.addWidget( + 'asset', + inputSpec.name, + displayLabel, + async () => { + if (!isAssetWidget(widget)) { + throw new Error(`Expected asset widget but received ${widget.type}`) + } + await assetBrowserDialog.show({ + nodeType: node.comfyClass || '', + inputName: inputSpec.name, + currentValue: widget.value, + onAssetSelected: (filename: string) => { + const oldValue = widget.value + widget.value = filename + // Using onWidgetChanged prevents a callback race where asset selection could reopen the dialog + node.onWidgetChanged?.(widget.name, filename, oldValue, widget) + } + }) + } + ) return widget } @@ -96,11 +113,14 @@ const addComboWidget = ( ) if (inputSpec.remote) { + if (!isComboWidget(widget)) { + throw new Error(`Expected combo widget but received ${widget.type}`) + } const remoteWidget = useRemoteWidget({ remoteConfig: inputSpec.remote, defaultValue, node, - widget: widget as IComboWidget + widget }) if (inputSpec.remote.refresh_button) remoteWidget.addRefreshButton() @@ -116,16 +136,19 @@ const addComboWidget = ( } if (inputSpec.control_after_generate) { + if (!isComboWidget(widget)) { + throw new Error(`Expected combo widget but received ${widget.type}`) + } widget.linkedWidgets = addValueControlWidgets( node, - widget as IComboWidget, + widget, undefined, undefined, transformInputSpecV2ToV1(inputSpec) ) } - return widget as IBaseWidget + return widget } export const useComboWidget = () => { diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index f6c15e91a..4f7925294 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -33,12 +33,43 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { ) }) + /** Internal computed for efficient reverse lookup: nodeType -> category */ + const nodeTypeToCategory = computed(() => { + const lookup: Record = {} + for (const [category, providers] of Object.entries(modelToNodeMap.value)) { + for (const provider of providers) { + // Only store the first category for each node type (matches current assetService behavior) + if (!lookup[provider.nodeDef.name]) { + lookup[provider.nodeDef.name] = category + } + } + } + return lookup + }) + /** Get set of all registered node types for efficient lookup */ function getRegisteredNodeTypes(): Set { registerDefaults() return registeredNodeTypes.value } + /** + * Get the category for a given node type. + * Performs efficient O(1) lookup using cached reverse map. + * @param nodeType The node type name to find the category for + * @returns The category name, or undefined if not found + */ + function getCategoryForNodeType(nodeType: string): string | undefined { + registerDefaults() + + // Handle invalid input gracefully + if (!nodeType || typeof nodeType !== 'string') { + return undefined + } + + return nodeTypeToCategory.value[nodeType] + } + /** * Get the node provider for the given model type name. * @param modelType The name of the model type to get the node provider for. @@ -109,6 +140,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { return { modelToNodeMap, getRegisteredNodeTypes, + getCategoryForNodeType, getNodeProvider, getAllNodeProviders, registerNodeProvider, diff --git a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts index d7d4f74dc..bef33733b 100644 --- a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts +++ b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts @@ -1,10 +1,33 @@ -import { describe, expect, it } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' import { nextTick } from 'vue' import { useAssetBrowser } from '@/platform/assets/composables/useAssetBrowser' import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { assetService } from '@/platform/assets/services/assetService' + +vi.mock('@/platform/assets/services/assetService', () => ({ + assetService: { + getAssetDetails: vi.fn() + } +})) + +vi.mock('@/i18n', () => ({ + t: (key: string) => { + const translations: Record = { + 'assetBrowser.allModels': 'All Models', + 'assetBrowser.assets': 'Assets', + 'assetBrowser.unknown': 'unknown' + } + return translations[key] || key + }, + d: (date: Date) => date.toLocaleDateString() +})) describe('useAssetBrowser', () => { + beforeEach(() => { + vi.restoreAllMocks() + }) + // Test fixtures - minimal data focused on functionality being tested const createApiAsset = (overrides: Partial = {}): AssetItem => ({ id: 'test-id', @@ -26,8 +49,8 @@ describe('useAssetBrowser', () => { user_metadata: { description: 'Test model' } }) - const { transformAssetForDisplay } = useAssetBrowser([apiAsset]) - const result = transformAssetForDisplay(apiAsset) + const { filteredAssets } = useAssetBrowser([apiAsset]) + const result = filteredAssets.value[0] // Get the transformed asset from filteredAssets // Preserves API properties expect(result.id).toBe(apiAsset.id) @@ -49,15 +72,13 @@ describe('useAssetBrowser', () => { user_metadata: undefined }) - const { transformAssetForDisplay } = useAssetBrowser([apiAsset]) - const result = transformAssetForDisplay(apiAsset) + const { filteredAssets } = useAssetBrowser([apiAsset]) + const result = filteredAssets.value[0] expect(result.description).toBe('loras model') }) it('formats various file sizes correctly', () => { - const { transformAssetForDisplay } = useAssetBrowser([]) - const testCases = [ { size: 512, expected: '512 B' }, { size: 1536, expected: '1.5 KB' }, @@ -67,7 +88,8 @@ describe('useAssetBrowser', () => { testCases.forEach(({ size, expected }) => { const asset = createApiAsset({ size }) - const result = transformAssetForDisplay(asset) + const { filteredAssets } = useAssetBrowser([asset]) + const result = filteredAssets.value[0] expect(result.formattedSize).toBe(expected) }) }) @@ -236,18 +258,182 @@ describe('useAssetBrowser', () => { }) }) - describe('Asset Selection', () => { - it('returns selected asset UUID for efficient handling', () => { + describe('Async Asset Selection with Detail Fetching', () => { + it('should fetch asset details and call onSelect with filename when provided', async () => { + const onSelectSpy = vi.fn() const asset = createApiAsset({ - id: 'test-uuid-123', - name: 'selected_model.safetensors' + id: 'asset-123', + name: 'test-model.safetensors' }) - const { selectAsset, transformAssetForDisplay } = useAssetBrowser([asset]) - const displayAsset = transformAssetForDisplay(asset) - const result = selectAsset(displayAsset) + const detailAsset = createApiAsset({ + id: 'asset-123', + name: 'test-model.safetensors', + user_metadata: { filename: 'checkpoints/test-model.safetensors' } + }) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) - expect(result).toBe('test-uuid-123') + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-123') + expect(onSelectSpy).toHaveBeenCalledWith( + 'checkpoints/test-model.safetensors' + ) + }) + + it('should handle missing user_metadata.filename as error', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + const asset = createApiAsset({ id: 'asset-456' }) + + const detailAsset = createApiAsset({ + id: 'asset-456', + user_metadata: { filename: '' } // Invalid empty filename + }) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-456') + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Filename cannot be empty' + }) + ]), + 'for asset:', + 'asset-456' + ) + }) + + it('should handle API errors gracefully', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + const asset = createApiAsset({ id: 'asset-789' }) + + const apiError = new Error('API Error') + vi.mocked(assetService.getAssetDetails).mockRejectedValue(apiError) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-789') + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining('Failed to fetch asset details for asset-789'), + apiError + ) + }) + + it('should not fetch details when no callback provided', async () => { + const asset = createApiAsset({ id: 'asset-no-callback' }) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id) + + expect(assetService.getAssetDetails).not.toHaveBeenCalled() + }) + }) + + describe('Filename Validation Security', () => { + const createValidationTest = (filename: string) => { + const testAsset = createApiAsset({ id: 'validation-test' }) + const detailAsset = createApiAsset({ + id: 'validation-test', + user_metadata: { filename } + }) + return { testAsset, detailAsset } + } + + it('accepts valid file paths with forward slashes', async () => { + const onSelectSpy = vi.fn() + const { testAsset, detailAsset } = createValidationTest( + 'models/checkpoints/v1/test-model.safetensors' + ) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).toHaveBeenCalledWith( + 'models/checkpoints/v1/test-model.safetensors' + ) + }) + + it('rejects directory traversal attacks', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + + const maliciousPaths = [ + '../malicious-model.safetensors', + 'models/../../../etc/passwd', + '/etc/passwd' + ] + + for (const path of maliciousPaths) { + const { testAsset, detailAsset } = createValidationTest(path) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Path must not start with / or contain ..' + }) + ]), + 'for asset:', + 'validation-test' + ) + } + }) + + it('rejects invalid filename characters', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + + const invalidChars = ['\\', ':', '*', '?', '"', '<', '>', '|'] + + for (const char of invalidChars) { + const { testAsset, detailAsset } = createValidationTest( + `bad${char}filename.safetensors` + ) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Invalid filename characters' + }) + ]), + 'for asset:', + 'validation-test' + ) + } }) }) diff --git a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts index fefeeceac..102aa7a18 100644 --- a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts +++ b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts @@ -6,11 +6,18 @@ import { useDialogStore } from '@/stores/dialogStore' // Mock the dialog store vi.mock('@/stores/dialogStore') +// Mock the asset service +vi.mock('@/platform/assets/services/assetService', () => ({ + assetService: { + getAssetsForNodeType: vi.fn().mockResolvedValue([]) + } +})) + // Test factory functions interface AssetBrowserProps { nodeType: string inputName: string - onAssetSelected?: ReturnType + onAssetSelected?: (filename: string) => void } function createAssetBrowserProps( @@ -25,7 +32,7 @@ function createAssetBrowserProps( describe('useAssetBrowserDialog', () => { describe('Asset Selection Flow', () => { - it('auto-closes dialog when asset is selected', () => { + it('auto-closes dialog when asset is selected', async () => { // Create fresh mocks for this test const mockShowDialog = vi.fn() const mockCloseDialog = vi.fn() @@ -41,7 +48,7 @@ describe('useAssetBrowserDialog', () => { const onAssetSelected = vi.fn() const props = createAssetBrowserProps({ onAssetSelected }) - assetBrowserDialog.show(props) + await assetBrowserDialog.show(props) // Get the onSelect handler that was passed to the dialog const dialogCall = mockShowDialog.mock.calls[0][0] @@ -50,14 +57,14 @@ describe('useAssetBrowserDialog', () => { // Simulate asset selection onSelectHandler('selected-asset-path') - // Should call the original callback and close dialog + // Should call the original callback and trigger hide animation expect(onAssetSelected).toHaveBeenCalledWith('selected-asset-path') expect(mockCloseDialog).toHaveBeenCalledWith({ key: 'global-asset-browser' }) }) - it('closes dialog when close handler is called', () => { + it('closes dialog when close handler is called', async () => { // Create fresh mocks for this test const mockShowDialog = vi.fn() const mockCloseDialog = vi.fn() @@ -72,7 +79,7 @@ describe('useAssetBrowserDialog', () => { const assetBrowserDialog = useAssetBrowserDialog() const props = createAssetBrowserProps() - assetBrowserDialog.show(props) + await assetBrowserDialog.show(props) // Get the onClose handler that was passed to the dialog const dialogCall = mockShowDialog.mock.calls[0][0] diff --git a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts index 875919ccd..d439d98ca 100644 --- a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts +++ b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts @@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { LGraphNode } from '@/lib/litegraph/src/litegraph' import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets' +import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog' import { assetService } from '@/platform/assets/services/assetService' import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget' import type { InputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2' @@ -29,13 +30,25 @@ vi.mock('@/platform/assets/services/assetService', () => ({ } })) +vi.mock('@/platform/assets/composables/useAssetBrowserDialog', () => { + const mockAssetBrowserDialogShow = vi.fn() + return { + useAssetBrowserDialog: vi.fn(() => ({ + show: mockAssetBrowserDialogShow + })) + } +}) + // Test factory functions function createMockWidget(overrides: Partial = {}): IBaseWidget { + const mockCallback = vi.fn() return { type: 'combo', options: {}, name: 'testWidget', value: undefined, + callback: mockCallback, + y: 0, ...overrides } as IBaseWidget } @@ -45,7 +58,16 @@ function createMockNode(comfyClass = 'TestNode'): LGraphNode { node.comfyClass = comfyClass // Spy on the addWidget method - vi.spyOn(node, 'addWidget').mockReturnValue(createMockWidget()) + vi.spyOn(node, 'addWidget').mockImplementation( + (type, name, value, callback) => { + const widget = createMockWidget({ type, name, value }) + // Store the callback function on the widget for testing + if (typeof callback === 'function') { + widget.callback = callback + } + return widget + } + ) return node } @@ -61,9 +83,9 @@ function createMockInputSpec(overrides: Partial = {}): InputSpec { describe('useComboWidget', () => { beforeEach(() => { vi.clearAllMocks() - // Reset to defaults mockSettingStoreGet.mockReturnValue(false) vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false) + vi.mocked(useAssetBrowserDialog).mockClear() }) it('should handle undefined spec', () => { diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts index d96ef765b..7a719c4e4 100644 --- a/tests-ui/tests/services/assetService.test.ts +++ b/tests-ui/tests/services/assetService.test.ts @@ -4,6 +4,8 @@ import type { AssetItem } from '@/platform/assets/schemas/assetSchema' import { assetService } from '@/platform/assets/services/assetService' import { api } from '@/scripts/api' +const mockGetCategoryForNodeType = vi.fn() + vi.mock('@/stores/modelToNodeStore', () => ({ useModelToNodeStore: vi.fn(() => ({ getRegisteredNodeTypes: vi.fn( @@ -14,7 +16,13 @@ vi.mock('@/stores/modelToNodeStore', () => ({ 'VAELoader', 'TestNode' ]) - ) + ), + getCategoryForNodeType: mockGetCategoryForNodeType, + modelToNodeMap: { + checkpoints: [{ nodeDef: { name: 'CheckpointLoaderSimple' } }], + loras: [{ nodeDef: { name: 'LoraLoader' } }], + vae: [{ nodeDef: { name: 'VAELoader' } }] + } })) })) @@ -210,4 +218,87 @@ describe('assetService', () => { ).toBe(false) }) }) + + describe('getAssetsForNodeType', () => { + beforeEach(() => { + mockGetCategoryForNodeType.mockClear() + }) + + it('should return empty array for unregistered node types', async () => { + mockGetCategoryForNodeType.mockReturnValue(undefined) + + const result = await assetService.getAssetsForNodeType('UnknownNode') + + expect(mockGetCategoryForNodeType).toHaveBeenCalledWith('UnknownNode') + expect(result).toEqual([]) + }) + + it('should use getCategoryForNodeType for efficient category lookup', async () => { + mockGetCategoryForNodeType.mockReturnValue('checkpoints') + const testAssets = [MOCK_ASSETS.checkpoints] + mockApiResponse(testAssets) + + const result = await assetService.getAssetsForNodeType( + 'CheckpointLoaderSimple' + ) + + expect(mockGetCategoryForNodeType).toHaveBeenCalledWith( + 'CheckpointLoaderSimple' + ) + expect(result).toEqual(testAssets) + + // Verify API call includes correct category + expect(api.fetchApi).toHaveBeenCalledWith( + '/assets?include_tags=models,checkpoints' + ) + }) + + it('should return empty array when no category found', async () => { + mockGetCategoryForNodeType.mockReturnValue(undefined) + + const result = await assetService.getAssetsForNodeType('TestNode') + + expect(result).toEqual([]) + expect(api.fetchApi).not.toHaveBeenCalled() + }) + + it('should handle API errors gracefully', async () => { + mockGetCategoryForNodeType.mockReturnValue('loras') + mockApiError(500, 'Internal Server Error') + + await expect( + assetService.getAssetsForNodeType('LoraLoader') + ).rejects.toThrow( + 'Unable to load assets for LoraLoader: Server returned 500. Please try again.' + ) + }) + + it('should return all assets without filtering for different categories', async () => { + // Test checkpoints + mockGetCategoryForNodeType.mockReturnValue('checkpoints') + const checkpointAssets = [MOCK_ASSETS.checkpoints] + mockApiResponse(checkpointAssets) + + let result = await assetService.getAssetsForNodeType( + 'CheckpointLoaderSimple' + ) + expect(result).toEqual(checkpointAssets) + + // Test loras + mockGetCategoryForNodeType.mockReturnValue('loras') + const loraAssets = [MOCK_ASSETS.loras] + mockApiResponse(loraAssets) + + result = await assetService.getAssetsForNodeType('LoraLoader') + expect(result).toEqual(loraAssets) + + // Test vae + mockGetCategoryForNodeType.mockReturnValue('vae') + const vaeAssets = [MOCK_ASSETS.vae] + mockApiResponse(vaeAssets) + + result = await assetService.getAssetsForNodeType('VAELoader') + expect(result).toEqual(vaeAssets) + }) + }) }) diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts index 179c76b9e..b07c34a41 100644 --- a/tests-ui/tests/store/modelToNodeStore.test.ts +++ b/tests-ui/tests/store/modelToNodeStore.test.ts @@ -19,7 +19,7 @@ const EXPECTED_DEFAULT_TYPES = [ 'gligen' ] as const -type NodeDefStoreType = typeof import('@/stores/nodeDefStore') +type NodeDefStoreType = ReturnType // Create minimal but valid ComfyNodeDefImpl for testing function createMockNodeDef(name: string): ComfyNodeDefImpl { @@ -343,6 +343,107 @@ describe('useModelToNodeStore', () => { }) }) + describe('getCategoryForNodeType', () => { + it('should return category for known node type', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + ).toBe('checkpoints') + expect(modelToNodeStore.getCategoryForNodeType('LoraLoader')).toBe( + 'loras' + ) + expect(modelToNodeStore.getCategoryForNodeType('VAELoader')).toBe('vae') + }) + + it('should return undefined for unknown node type', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('NonExistentNode') + ).toBeUndefined() + expect(modelToNodeStore.getCategoryForNodeType('')).toBeUndefined() + }) + + it('should return first category when node type exists in multiple categories', () => { + const modelToNodeStore = useModelToNodeStore() + + // Test with a node that exists in the defaults but add our own first + // Since defaults register 'StyleModelLoader' in 'style_models', + // we verify our custom registrations come after defaults in Object.entries iteration + const result = modelToNodeStore.getCategoryForNodeType('StyleModelLoader') + expect(result).toBe('style_models') // This proves the method works correctly + + // Now test that custom registrations after defaults also work + modelToNodeStore.quickRegister( + 'unicorn_styles', + 'StyleModelLoader', + 'param1' + ) + const result2 = + modelToNodeStore.getCategoryForNodeType('StyleModelLoader') + // Should still be style_models since it was registered first by defaults + expect(result2).toBe('style_models') + }) + + it('should trigger lazy registration when called before registerDefaults', () => { + const modelToNodeStore = useModelToNodeStore() + + const result = modelToNodeStore.getCategoryForNodeType( + 'CheckpointLoaderSimple' + ) + expect(result).toBe('checkpoints') + }) + + it('should be performant for repeated lookups', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + // Measure performance without assuming implementation + const start = performance.now() + for (let i = 0; i < 1000; i++) { + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + } + const end = performance.now() + + // Should be fast enough for UI responsiveness + expect(end - start).toBeLessThan(10) + }) + + it('should handle invalid input types gracefully', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + // These should not throw but return undefined + expect( + modelToNodeStore.getCategoryForNodeType(null as any) + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType(undefined as any) + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType(123 as any) + ).toBeUndefined() + }) + + it('should be case-sensitive for node type matching', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('checkpointloadersimple') + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType('CHECKPOINTLOADERSIMPLE') + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + ).toBe('checkpoints') + }) + }) + describe('edge cases', () => { it('should handle empty string model type', () => { const modelToNodeStore = useModelToNodeStore()