From fd125917563e68e3126ede2b53e82594b0f0303c Mon Sep 17 00:00:00 2001 From: Arjan Singh <1598641+arjansingh@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:44:18 -0700 Subject: [PATCH] [feat] integrate asset browser with widget system (#5629) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add asset browser dialog integration for combo widgets with full animation support and proper state management. (Thank you Claude from saving me me from merge conflict hell on this one.) ## Changes - Widget integration: combo widgets now use AssetBrowserModal for eligible asset types - Dialog animations: added animateHide() for smooth close transitions - Async operations: proper sequencing of widget updates and dialog animations - Service layer: added getAssetsForNodeType() and getAssetDetails() methods - Type safety: comprehensive TypeScript types and error handling - Test coverage: unit tests for all new functionality - Bonus: fixed the hardcoded labels in AssetFilterBar Widget behavior: - Shows asset browser button for eligible widgets when asset API enabled - Handles asset selection with proper callback sequencing - Maintains widget value updates and litegraph notification ## Review Focus I will call out some stuff inline. ## Screenshots https://github.com/user-attachments/assets/9d3a72cf-d2b0-445f-8022-4c49daa04637 ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-5629-feat-integrate-asset-browser-with-widget-system-2726d73d365081a9a98be9a2307aee0b) by [Unito](https://www.unito.io) --------- Co-authored-by: Claude Co-authored-by: GitHub Action --- package.json | 2 + src/i18n.ts | 2 +- src/lib/litegraph/src/litegraph.ts | 2 +- src/lib/litegraph/src/widgets/AssetWidget.ts | 16 ++ src/lib/litegraph/src/widgets/widgetMap.ts | 6 + src/locales/en/main.json | 9 +- .../components/AssetBrowserModal.stories.ts | 7 +- .../assets/components/AssetBrowserModal.vue | 24 +- src/platform/assets/components/AssetCard.vue | 2 +- .../assets/components/AssetFilterBar.vue | 15 +- .../assets/composables/useAssetBrowser.ts | 47 +++- .../composables/useAssetBrowserDialog.ts | 51 ++-- src/platform/assets/schemas/assetSchema.ts | 14 +- src/platform/assets/services/assetService.ts | 68 +++++- .../widgets/composables/useComboWidget.ts | 47 +++- src/stores/modelToNodeStore.ts | 32 +++ .../composables/useAssetBrowser.test.ts | 218 ++++++++++++++++-- .../composables/useAssetBrowserDialog.test.ts | 19 +- .../composables/useComboWidget.test.ts | 26 ++- tests-ui/tests/services/assetService.test.ts | 93 +++++++- tests-ui/tests/store/modelToNodeStore.test.ts | 103 ++++++++- 21 files changed, 701 insertions(+), 102 deletions(-) diff --git a/package.json b/package.json index 770ef7e04..a1089933f 100644 --- a/package.json +++ b/package.json @@ -27,6 +27,8 @@ "preview": "nx preview", "lint": "eslint src --cache", "lint:fix": "eslint src --cache --fix", + "lint:unstaged": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache", + "lint:unstaged:fix": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache --fix", "lint:no-cache": "eslint src", "lint:fix:no-cache": "eslint src --fix", "knip": "knip --cache", diff --git a/src/i18n.ts b/src/i18n.ts index 102ac2600..38a8dfe95 100644 --- a/src/i18n.ts +++ b/src/i18n.ts @@ -76,7 +76,7 @@ export const i18n = createI18n({ }) /** Convenience shorthand: i18n.global */ -export const { t, te } = i18n.global +export const { t, te, d } = i18n.global /** * Safe translation function that returns the fallback message if the key is not found. diff --git a/src/lib/litegraph/src/litegraph.ts b/src/lib/litegraph/src/litegraph.ts index 46202a219..46b094af0 100644 --- a/src/lib/litegraph/src/litegraph.ts +++ b/src/lib/litegraph/src/litegraph.ts @@ -140,7 +140,7 @@ export { BaseWidget } from './widgets/BaseWidget' export { LegacyWidget } from './widgets/LegacyWidget' -export { isComboWidget } from './widgets/widgetMap' +export { isComboWidget, isAssetWidget } from './widgets/widgetMap' // Additional test-specific exports export { LGraphButton } from './LGraphButton' export { MovingOutputLink } from './canvas/MovingOutputLink' diff --git a/src/lib/litegraph/src/widgets/AssetWidget.ts b/src/lib/litegraph/src/widgets/AssetWidget.ts index f8a8e1209..1a5047beb 100644 --- a/src/lib/litegraph/src/widgets/AssetWidget.ts +++ b/src/lib/litegraph/src/widgets/AssetWidget.ts @@ -13,6 +13,22 @@ export class AssetWidget this.value = widget.value?.toString() ?? '' } + override set value(value: IAssetWidget['value']) { + const oldValue = this.value + super.value = value + + // Force canvas redraw when value changes to show update immediately + if (oldValue !== value && this.node.graph?.list_of_graphcanvas) { + for (const canvas of this.node.graph.list_of_graphcanvas) { + canvas.setDirty(true) + } + } + } + + override get value(): IAssetWidget['value'] { + return super.value + } + override get _displayValue(): string { return String(this.value) //FIXME: Resolve asset name } diff --git a/src/lib/litegraph/src/widgets/widgetMap.ts b/src/lib/litegraph/src/widgets/widgetMap.ts index 02cdb5597..0e6a34fe5 100644 --- a/src/lib/litegraph/src/widgets/widgetMap.ts +++ b/src/lib/litegraph/src/widgets/widgetMap.ts @@ -1,5 +1,6 @@ import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode' import type { + IAssetWidget, IBaseWidget, IComboWidget, IWidget, @@ -132,4 +133,9 @@ export function isComboWidget(widget: IBaseWidget): widget is IComboWidget { return widget.type === 'combo' } +/** Type guard: Narrow **from {@link IBaseWidget}** to {@link IAssetWidget}. */ +export function isAssetWidget(widget: IBaseWidget): widget is IAssetWidget { + return widget.type === 'asset' +} + // #endregion Type Guards diff --git a/src/locales/en/main.json b/src/locales/en/main.json index 3abd85335..74f0352ae 100644 --- a/src/locales/en/main.json +++ b/src/locales/en/main.json @@ -1873,6 +1873,13 @@ "noModelsInFolder": "No {type} available in this folder", "searchAssetsPlaceholder": "Search assets...", "allModels": "All Models", - "unknown": "Unknown" + "unknown": "Unknown", + "fileFormats": "File formats", + "baseModels": "Base models", + "sortBy": "Sort by", + "sortAZ": "A-Z", + "sortZA": "Z-A", + "sortRecent": "Recent", + "sortPopular": "Popular" } } diff --git a/src/platform/assets/components/AssetBrowserModal.stories.ts b/src/platform/assets/components/AssetBrowserModal.stories.ts index acc93181d..9d2321d57 100644 --- a/src/platform/assets/components/AssetBrowserModal.stories.ts +++ b/src/platform/assets/components/AssetBrowserModal.stories.ts @@ -1,6 +1,7 @@ import type { Meta, StoryObj } from '@storybook/vue3-vite' import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue' +import type { AssetDisplayItem } from '@/platform/assets/composables/useAssetBrowser' import { createMockAssets, mockAssets @@ -56,7 +57,7 @@ export const Default: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { @@ -96,7 +97,7 @@ export const SingleAssetType: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { @@ -145,7 +146,7 @@ export const NoLeftPanel: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { diff --git a/src/platform/assets/components/AssetBrowserModal.vue b/src/platform/assets/components/AssetBrowserModal.vue index de05f437d..cb45f38ba 100644 --- a/src/platform/assets/components/AssetBrowserModal.vue +++ b/src/platform/assets/components/AssetBrowserModal.vue @@ -12,7 +12,7 @@ :nav-items="availableCategories" > @@ -37,7 +37,7 @@ diff --git a/src/platform/assets/components/AssetCard.vue b/src/platform/assets/components/AssetCard.vue index e379099c1..be7c45ca5 100644 --- a/src/platform/assets/components/AssetCard.vue +++ b/src/platform/assets/components/AssetCard.vue @@ -14,7 +14,7 @@ 'bg-ivory-100 border border-gray-300 dark-theme:bg-charcoal-400 dark-theme:border-charcoal-600', 'hover:transform hover:-translate-y-0.5 hover:shadow-lg hover:shadow-black/10 hover:border-gray-400', 'dark-theme:hover:shadow-lg dark-theme:hover:shadow-black/30 dark-theme:hover:border-charcoal-700', - 'focus:outline-none focus:ring-2 focus:ring-blue-500 dark-theme:focus:ring-blue-400' + 'focus:outline-none focus:transform focus:-translate-y-0.5 focus:shadow-lg focus:shadow-black/10 dark-theme:focus:shadow-black/30' ], // Div-specific styles !interactive && [ diff --git a/src/platform/assets/components/AssetFilterBar.vue b/src/platform/assets/components/AssetFilterBar.vue index 1f3295b43..904ce3e82 100644 --- a/src/platform/assets/components/AssetFilterBar.vue +++ b/src/platform/assets/components/AssetFilterBar.vue @@ -3,7 +3,7 @@
void + ): Promise { if (import.meta.env.DEV) { - console.log('Asset selected:', asset.id, asset.name) + console.debug('Asset selected:', assetId) + } + + if (!onSelect) { + return + } + + try { + const detailAsset = await assetService.getAssetDetails(assetId) + const filename = detailAsset.user_metadata?.filename + const validatedFilename = assetFilenameSchema.safeParse(filename) + if (!validatedFilename.success) { + console.error( + 'Invalid asset filename:', + validatedFilename.error.errors, + 'for asset:', + assetId + ) + return + } + + onSelect(validatedFilename.data) + } catch (error) { + console.error(`Failed to fetch asset details for ${assetId}:`, error) } - return asset.id } return { @@ -182,7 +212,6 @@ export function useAssetBrowser(assets: AssetItem[] = []) { filteredAssets, // Actions - selectAsset, - transformAssetForDisplay + selectAssetWithCallback } } diff --git a/src/platform/assets/composables/useAssetBrowserDialog.ts b/src/platform/assets/composables/useAssetBrowserDialog.ts index e5f63eead..31f75c353 100644 --- a/src/platform/assets/composables/useAssetBrowserDialog.ts +++ b/src/platform/assets/composables/useAssetBrowserDialog.ts @@ -1,5 +1,7 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue' -import { useDialogStore } from '@/stores/dialogStore' +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { assetService } from '@/platform/assets/services/assetService' +import { type DialogComponentProps, useDialogStore } from '@/stores/dialogStore' interface AssetBrowserDialogProps { /** ComfyUI node type for context (e.g., 'CheckpointLoaderSimple') */ @@ -8,36 +10,29 @@ interface AssetBrowserDialogProps { inputName: string /** Current selected asset value */ currentValue?: string - /** Callback for when an asset is selected */ - onAssetSelected?: (assetPath: string) => void + /** + * Callback for when an asset is selected + * @param {string} filename - The validated filename from user_metadata.filename + */ + onAssetSelected?: (filename: string) => void } export const useAssetBrowserDialog = () => { const dialogStore = useDialogStore() const dialogKey = 'global-asset-browser' - function hide() { - dialogStore.closeDialog({ key: dialogKey }) - } - - function show(props: AssetBrowserDialogProps) { - const handleAssetSelected = (assetPath: string) => { - props.onAssetSelected?.(assetPath) - hide() // Auto-close on selection + async function show(props: AssetBrowserDialogProps) { + const handleAssetSelected = (filename: string) => { + props.onAssetSelected?.(filename) + dialogStore.closeDialog({ key: dialogKey }) } - - const handleClose = () => { - hide() - } - - // Default dialog configuration for AssetBrowserModal - const dialogComponentProps = { + const dialogComponentProps: DialogComponentProps = { headless: true, modal: true, - closable: false, + closable: true, pt: { root: { - class: 'rounded-2xl overflow-hidden' + class: 'rounded-2xl overflow-hidden asset-browser-dialog' }, header: { class: 'p-0 hidden' @@ -48,6 +43,17 @@ export const useAssetBrowserDialog = () => { } } + const assets: AssetItem[] = await assetService + .getAssetsForNodeType(props.nodeType) + .catch((error) => { + console.error( + 'Failed to fetch assets for node type:', + props.nodeType, + error + ) + return [] + }) + dialogStore.showDialog({ key: dialogKey, component: AssetBrowserModal, @@ -55,12 +61,13 @@ export const useAssetBrowserDialog = () => { nodeType: props.nodeType, inputName: props.inputName, currentValue: props.currentValue, + assets, onSelect: handleAssetSelected, - onClose: handleClose + onClose: () => dialogStore.closeDialog({ key: dialogKey }) }, dialogComponentProps }) } - return { show, hide } + return { show } } diff --git a/src/platform/assets/schemas/assetSchema.ts b/src/platform/assets/schemas/assetSchema.ts index fab41649a..2c051a30d 100644 --- a/src/platform/assets/schemas/assetSchema.ts +++ b/src/platform/assets/schemas/assetSchema.ts @@ -4,13 +4,13 @@ import { z } from 'zod' const zAsset = z.object({ id: z.string(), name: z.string(), - asset_hash: z.string(), + asset_hash: z.string().nullable(), size: z.number(), - mime_type: z.string(), + mime_type: z.string().nullable(), tags: z.array(z.string()), preview_url: z.string().optional(), created_at: z.string(), - updated_at: z.string(), + updated_at: z.string().optional(), last_access_time: z.string(), user_metadata: z.record(z.unknown()).optional(), // API allows arbitrary key-value pairs preview_id: z.string().nullable().optional() @@ -33,6 +33,14 @@ const zModelFile = z.object({ pathIndex: z.number() }) +// Filename validation schema +export const assetFilenameSchema = z + .string() + .min(1, 'Filename cannot be empty') + .regex(/^[^\\:*?"<>|]+$/, 'Invalid filename characters') // Allow forward slashes, block backslashes and other unsafe chars + .regex(/^(?!\/|.*\.\.)/, 'Path must not start with / or contain ..') // Prevent absolute paths and directory traversal + .trim() + // Export schemas following repository patterns export const assetResponseSchema = zAssetResponse diff --git a/src/platform/assets/services/assetService.ts b/src/platform/assets/services/assetService.ts index 74b20a753..7d0f82cbb 100644 --- a/src/platform/assets/services/assetService.ts +++ b/src/platform/assets/services/assetService.ts @@ -1,6 +1,7 @@ import { fromZodError } from 'zod-validation-error' import { + type AssetItem, type AssetResponse, type ModelFile, type ModelFolder, @@ -127,10 +128,75 @@ function createAssetService() { ) } + /** + * Gets assets for a specific node type by finding the matching category + * and fetching all assets with that category tag + * + * @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple') + * @returns Promise - Full asset objects with preserved metadata + */ + async function getAssetsForNodeType(nodeType: string): Promise { + if (!nodeType || typeof nodeType !== 'string') { + return [] + } + + // Find the category for this node type using efficient O(1) lookup + const modelToNodeStore = useModelToNodeStore() + const category = modelToNodeStore.getCategoryForNodeType(nodeType) + + if (!category) { + return [] + } + + // Fetch assets for this category using same API pattern as getAssetModels + const data = await handleAssetRequest( + `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}`, + `assets for ${nodeType}` + ) + + // Return full AssetItem[] objects (don't strip like getAssetModels does) + return ( + data?.assets?.filter( + (asset) => + !asset.tags.includes(MISSING_TAG) && asset.tags.includes(category) + ) ?? [] + ) + } + + /** + * Gets complete details for a specific asset by ID + * Calls the detail endpoint which includes user_metadata and all fields + * + * @param id - The asset ID + * @returns Promise - Complete asset object with user_metadata + */ + async function getAssetDetails(id: string): Promise { + const res = await api.fetchApi(`${ASSETS_ENDPOINT}/${id}`) + if (!res.ok) { + throw new Error( + `Unable to load asset details for ${id}: Server returned ${res.status}. Please try again.` + ) + } + const data = await res.json() + + // Validate the single asset response against our schema + const result = assetResponseSchema.safeParse({ assets: [data] }) + if (result.success && result.data.assets?.[0]) { + return result.data.assets[0] + } + + const error = result.error + ? fromZodError(result.error) + : 'Unknown validation error' + throw new Error(`Invalid asset response against zod schema:\n${error}`) + } + return { getAssetModelFolders, getAssetModels, - isAssetBrowserEligible + isAssetBrowserEligible, + getAssetsForNodeType, + getAssetDetails } } diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts index 2387fc59c..59705458c 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts @@ -3,10 +3,9 @@ import { ref } from 'vue' import MultiSelectWidget from '@/components/graph/widgets/MultiSelectWidget.vue' import { t } from '@/i18n' import type { LGraphNode } from '@/lib/litegraph/src/litegraph' -import type { - IBaseWidget, - IComboWidget -} from '@/lib/litegraph/src/types/widgets' +import { isAssetWidget, isComboWidget } from '@/lib/litegraph/src/litegraph' +import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets' +import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog' import { assetService } from '@/platform/assets/services/assetService' import { useSettingStore } from '@/platform/settings/settingStore' import { transformInputSpecV2ToV1 } from '@/schemas/nodeDef/migration' @@ -73,11 +72,29 @@ const addComboWidget = ( const currentValue = getDefaultValue(inputSpec) const displayLabel = currentValue ?? t('widgets.selectModel') - const widget = node.addWidget('asset', inputSpec.name, displayLabel, () => { - console.log( - `Asset Browser would open here for:\nNode: ${node.type}\nWidget: ${inputSpec.name}\nCurrent Value:${currentValue}` - ) - }) + const assetBrowserDialog = useAssetBrowserDialog() + + const widget = node.addWidget( + 'asset', + inputSpec.name, + displayLabel, + async () => { + if (!isAssetWidget(widget)) { + throw new Error(`Expected asset widget but received ${widget.type}`) + } + await assetBrowserDialog.show({ + nodeType: node.comfyClass || '', + inputName: inputSpec.name, + currentValue: widget.value, + onAssetSelected: (filename: string) => { + const oldValue = widget.value + widget.value = filename + // Using onWidgetChanged prevents a callback race where asset selection could reopen the dialog + node.onWidgetChanged?.(widget.name, filename, oldValue, widget) + } + }) + } + ) return widget } @@ -96,11 +113,14 @@ const addComboWidget = ( ) if (inputSpec.remote) { + if (!isComboWidget(widget)) { + throw new Error(`Expected combo widget but received ${widget.type}`) + } const remoteWidget = useRemoteWidget({ remoteConfig: inputSpec.remote, defaultValue, node, - widget: widget as IComboWidget + widget }) if (inputSpec.remote.refresh_button) remoteWidget.addRefreshButton() @@ -116,16 +136,19 @@ const addComboWidget = ( } if (inputSpec.control_after_generate) { + if (!isComboWidget(widget)) { + throw new Error(`Expected combo widget but received ${widget.type}`) + } widget.linkedWidgets = addValueControlWidgets( node, - widget as IComboWidget, + widget, undefined, undefined, transformInputSpecV2ToV1(inputSpec) ) } - return widget as IBaseWidget + return widget } export const useComboWidget = () => { diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index f6c15e91a..4f7925294 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -33,12 +33,43 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { ) }) + /** Internal computed for efficient reverse lookup: nodeType -> category */ + const nodeTypeToCategory = computed(() => { + const lookup: Record = {} + for (const [category, providers] of Object.entries(modelToNodeMap.value)) { + for (const provider of providers) { + // Only store the first category for each node type (matches current assetService behavior) + if (!lookup[provider.nodeDef.name]) { + lookup[provider.nodeDef.name] = category + } + } + } + return lookup + }) + /** Get set of all registered node types for efficient lookup */ function getRegisteredNodeTypes(): Set { registerDefaults() return registeredNodeTypes.value } + /** + * Get the category for a given node type. + * Performs efficient O(1) lookup using cached reverse map. + * @param nodeType The node type name to find the category for + * @returns The category name, or undefined if not found + */ + function getCategoryForNodeType(nodeType: string): string | undefined { + registerDefaults() + + // Handle invalid input gracefully + if (!nodeType || typeof nodeType !== 'string') { + return undefined + } + + return nodeTypeToCategory.value[nodeType] + } + /** * Get the node provider for the given model type name. * @param modelType The name of the model type to get the node provider for. @@ -109,6 +140,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { return { modelToNodeMap, getRegisteredNodeTypes, + getCategoryForNodeType, getNodeProvider, getAllNodeProviders, registerNodeProvider, diff --git a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts index d7d4f74dc..bef33733b 100644 --- a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts +++ b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts @@ -1,10 +1,33 @@ -import { describe, expect, it } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' import { nextTick } from 'vue' import { useAssetBrowser } from '@/platform/assets/composables/useAssetBrowser' import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { assetService } from '@/platform/assets/services/assetService' + +vi.mock('@/platform/assets/services/assetService', () => ({ + assetService: { + getAssetDetails: vi.fn() + } +})) + +vi.mock('@/i18n', () => ({ + t: (key: string) => { + const translations: Record = { + 'assetBrowser.allModels': 'All Models', + 'assetBrowser.assets': 'Assets', + 'assetBrowser.unknown': 'unknown' + } + return translations[key] || key + }, + d: (date: Date) => date.toLocaleDateString() +})) describe('useAssetBrowser', () => { + beforeEach(() => { + vi.restoreAllMocks() + }) + // Test fixtures - minimal data focused on functionality being tested const createApiAsset = (overrides: Partial = {}): AssetItem => ({ id: 'test-id', @@ -26,8 +49,8 @@ describe('useAssetBrowser', () => { user_metadata: { description: 'Test model' } }) - const { transformAssetForDisplay } = useAssetBrowser([apiAsset]) - const result = transformAssetForDisplay(apiAsset) + const { filteredAssets } = useAssetBrowser([apiAsset]) + const result = filteredAssets.value[0] // Get the transformed asset from filteredAssets // Preserves API properties expect(result.id).toBe(apiAsset.id) @@ -49,15 +72,13 @@ describe('useAssetBrowser', () => { user_metadata: undefined }) - const { transformAssetForDisplay } = useAssetBrowser([apiAsset]) - const result = transformAssetForDisplay(apiAsset) + const { filteredAssets } = useAssetBrowser([apiAsset]) + const result = filteredAssets.value[0] expect(result.description).toBe('loras model') }) it('formats various file sizes correctly', () => { - const { transformAssetForDisplay } = useAssetBrowser([]) - const testCases = [ { size: 512, expected: '512 B' }, { size: 1536, expected: '1.5 KB' }, @@ -67,7 +88,8 @@ describe('useAssetBrowser', () => { testCases.forEach(({ size, expected }) => { const asset = createApiAsset({ size }) - const result = transformAssetForDisplay(asset) + const { filteredAssets } = useAssetBrowser([asset]) + const result = filteredAssets.value[0] expect(result.formattedSize).toBe(expected) }) }) @@ -236,18 +258,182 @@ describe('useAssetBrowser', () => { }) }) - describe('Asset Selection', () => { - it('returns selected asset UUID for efficient handling', () => { + describe('Async Asset Selection with Detail Fetching', () => { + it('should fetch asset details and call onSelect with filename when provided', async () => { + const onSelectSpy = vi.fn() const asset = createApiAsset({ - id: 'test-uuid-123', - name: 'selected_model.safetensors' + id: 'asset-123', + name: 'test-model.safetensors' }) - const { selectAsset, transformAssetForDisplay } = useAssetBrowser([asset]) - const displayAsset = transformAssetForDisplay(asset) - const result = selectAsset(displayAsset) + const detailAsset = createApiAsset({ + id: 'asset-123', + name: 'test-model.safetensors', + user_metadata: { filename: 'checkpoints/test-model.safetensors' } + }) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) - expect(result).toBe('test-uuid-123') + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-123') + expect(onSelectSpy).toHaveBeenCalledWith( + 'checkpoints/test-model.safetensors' + ) + }) + + it('should handle missing user_metadata.filename as error', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + const asset = createApiAsset({ id: 'asset-456' }) + + const detailAsset = createApiAsset({ + id: 'asset-456', + user_metadata: { filename: '' } // Invalid empty filename + }) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-456') + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Filename cannot be empty' + }) + ]), + 'for asset:', + 'asset-456' + ) + }) + + it('should handle API errors gracefully', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + const asset = createApiAsset({ id: 'asset-789' }) + + const apiError = new Error('API Error') + vi.mocked(assetService.getAssetDetails).mockRejectedValue(apiError) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-789') + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining('Failed to fetch asset details for asset-789'), + apiError + ) + }) + + it('should not fetch details when no callback provided', async () => { + const asset = createApiAsset({ id: 'asset-no-callback' }) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id) + + expect(assetService.getAssetDetails).not.toHaveBeenCalled() + }) + }) + + describe('Filename Validation Security', () => { + const createValidationTest = (filename: string) => { + const testAsset = createApiAsset({ id: 'validation-test' }) + const detailAsset = createApiAsset({ + id: 'validation-test', + user_metadata: { filename } + }) + return { testAsset, detailAsset } + } + + it('accepts valid file paths with forward slashes', async () => { + const onSelectSpy = vi.fn() + const { testAsset, detailAsset } = createValidationTest( + 'models/checkpoints/v1/test-model.safetensors' + ) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).toHaveBeenCalledWith( + 'models/checkpoints/v1/test-model.safetensors' + ) + }) + + it('rejects directory traversal attacks', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + + const maliciousPaths = [ + '../malicious-model.safetensors', + 'models/../../../etc/passwd', + '/etc/passwd' + ] + + for (const path of maliciousPaths) { + const { testAsset, detailAsset } = createValidationTest(path) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Path must not start with / or contain ..' + }) + ]), + 'for asset:', + 'validation-test' + ) + } + }) + + it('rejects invalid filename characters', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + + const invalidChars = ['\\', ':', '*', '?', '"', '<', '>', '|'] + + for (const char of invalidChars) { + const { testAsset, detailAsset } = createValidationTest( + `bad${char}filename.safetensors` + ) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Invalid filename characters' + }) + ]), + 'for asset:', + 'validation-test' + ) + } }) }) diff --git a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts index fefeeceac..102aa7a18 100644 --- a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts +++ b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts @@ -6,11 +6,18 @@ import { useDialogStore } from '@/stores/dialogStore' // Mock the dialog store vi.mock('@/stores/dialogStore') +// Mock the asset service +vi.mock('@/platform/assets/services/assetService', () => ({ + assetService: { + getAssetsForNodeType: vi.fn().mockResolvedValue([]) + } +})) + // Test factory functions interface AssetBrowserProps { nodeType: string inputName: string - onAssetSelected?: ReturnType + onAssetSelected?: (filename: string) => void } function createAssetBrowserProps( @@ -25,7 +32,7 @@ function createAssetBrowserProps( describe('useAssetBrowserDialog', () => { describe('Asset Selection Flow', () => { - it('auto-closes dialog when asset is selected', () => { + it('auto-closes dialog when asset is selected', async () => { // Create fresh mocks for this test const mockShowDialog = vi.fn() const mockCloseDialog = vi.fn() @@ -41,7 +48,7 @@ describe('useAssetBrowserDialog', () => { const onAssetSelected = vi.fn() const props = createAssetBrowserProps({ onAssetSelected }) - assetBrowserDialog.show(props) + await assetBrowserDialog.show(props) // Get the onSelect handler that was passed to the dialog const dialogCall = mockShowDialog.mock.calls[0][0] @@ -50,14 +57,14 @@ describe('useAssetBrowserDialog', () => { // Simulate asset selection onSelectHandler('selected-asset-path') - // Should call the original callback and close dialog + // Should call the original callback and trigger hide animation expect(onAssetSelected).toHaveBeenCalledWith('selected-asset-path') expect(mockCloseDialog).toHaveBeenCalledWith({ key: 'global-asset-browser' }) }) - it('closes dialog when close handler is called', () => { + it('closes dialog when close handler is called', async () => { // Create fresh mocks for this test const mockShowDialog = vi.fn() const mockCloseDialog = vi.fn() @@ -72,7 +79,7 @@ describe('useAssetBrowserDialog', () => { const assetBrowserDialog = useAssetBrowserDialog() const props = createAssetBrowserProps() - assetBrowserDialog.show(props) + await assetBrowserDialog.show(props) // Get the onClose handler that was passed to the dialog const dialogCall = mockShowDialog.mock.calls[0][0] diff --git a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts index 875919ccd..d439d98ca 100644 --- a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts +++ b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts @@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { LGraphNode } from '@/lib/litegraph/src/litegraph' import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets' +import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog' import { assetService } from '@/platform/assets/services/assetService' import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget' import type { InputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2' @@ -29,13 +30,25 @@ vi.mock('@/platform/assets/services/assetService', () => ({ } })) +vi.mock('@/platform/assets/composables/useAssetBrowserDialog', () => { + const mockAssetBrowserDialogShow = vi.fn() + return { + useAssetBrowserDialog: vi.fn(() => ({ + show: mockAssetBrowserDialogShow + })) + } +}) + // Test factory functions function createMockWidget(overrides: Partial = {}): IBaseWidget { + const mockCallback = vi.fn() return { type: 'combo', options: {}, name: 'testWidget', value: undefined, + callback: mockCallback, + y: 0, ...overrides } as IBaseWidget } @@ -45,7 +58,16 @@ function createMockNode(comfyClass = 'TestNode'): LGraphNode { node.comfyClass = comfyClass // Spy on the addWidget method - vi.spyOn(node, 'addWidget').mockReturnValue(createMockWidget()) + vi.spyOn(node, 'addWidget').mockImplementation( + (type, name, value, callback) => { + const widget = createMockWidget({ type, name, value }) + // Store the callback function on the widget for testing + if (typeof callback === 'function') { + widget.callback = callback + } + return widget + } + ) return node } @@ -61,9 +83,9 @@ function createMockInputSpec(overrides: Partial = {}): InputSpec { describe('useComboWidget', () => { beforeEach(() => { vi.clearAllMocks() - // Reset to defaults mockSettingStoreGet.mockReturnValue(false) vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false) + vi.mocked(useAssetBrowserDialog).mockClear() }) it('should handle undefined spec', () => { diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts index d96ef765b..7a719c4e4 100644 --- a/tests-ui/tests/services/assetService.test.ts +++ b/tests-ui/tests/services/assetService.test.ts @@ -4,6 +4,8 @@ import type { AssetItem } from '@/platform/assets/schemas/assetSchema' import { assetService } from '@/platform/assets/services/assetService' import { api } from '@/scripts/api' +const mockGetCategoryForNodeType = vi.fn() + vi.mock('@/stores/modelToNodeStore', () => ({ useModelToNodeStore: vi.fn(() => ({ getRegisteredNodeTypes: vi.fn( @@ -14,7 +16,13 @@ vi.mock('@/stores/modelToNodeStore', () => ({ 'VAELoader', 'TestNode' ]) - ) + ), + getCategoryForNodeType: mockGetCategoryForNodeType, + modelToNodeMap: { + checkpoints: [{ nodeDef: { name: 'CheckpointLoaderSimple' } }], + loras: [{ nodeDef: { name: 'LoraLoader' } }], + vae: [{ nodeDef: { name: 'VAELoader' } }] + } })) })) @@ -210,4 +218,87 @@ describe('assetService', () => { ).toBe(false) }) }) + + describe('getAssetsForNodeType', () => { + beforeEach(() => { + mockGetCategoryForNodeType.mockClear() + }) + + it('should return empty array for unregistered node types', async () => { + mockGetCategoryForNodeType.mockReturnValue(undefined) + + const result = await assetService.getAssetsForNodeType('UnknownNode') + + expect(mockGetCategoryForNodeType).toHaveBeenCalledWith('UnknownNode') + expect(result).toEqual([]) + }) + + it('should use getCategoryForNodeType for efficient category lookup', async () => { + mockGetCategoryForNodeType.mockReturnValue('checkpoints') + const testAssets = [MOCK_ASSETS.checkpoints] + mockApiResponse(testAssets) + + const result = await assetService.getAssetsForNodeType( + 'CheckpointLoaderSimple' + ) + + expect(mockGetCategoryForNodeType).toHaveBeenCalledWith( + 'CheckpointLoaderSimple' + ) + expect(result).toEqual(testAssets) + + // Verify API call includes correct category + expect(api.fetchApi).toHaveBeenCalledWith( + '/assets?include_tags=models,checkpoints' + ) + }) + + it('should return empty array when no category found', async () => { + mockGetCategoryForNodeType.mockReturnValue(undefined) + + const result = await assetService.getAssetsForNodeType('TestNode') + + expect(result).toEqual([]) + expect(api.fetchApi).not.toHaveBeenCalled() + }) + + it('should handle API errors gracefully', async () => { + mockGetCategoryForNodeType.mockReturnValue('loras') + mockApiError(500, 'Internal Server Error') + + await expect( + assetService.getAssetsForNodeType('LoraLoader') + ).rejects.toThrow( + 'Unable to load assets for LoraLoader: Server returned 500. Please try again.' + ) + }) + + it('should return all assets without filtering for different categories', async () => { + // Test checkpoints + mockGetCategoryForNodeType.mockReturnValue('checkpoints') + const checkpointAssets = [MOCK_ASSETS.checkpoints] + mockApiResponse(checkpointAssets) + + let result = await assetService.getAssetsForNodeType( + 'CheckpointLoaderSimple' + ) + expect(result).toEqual(checkpointAssets) + + // Test loras + mockGetCategoryForNodeType.mockReturnValue('loras') + const loraAssets = [MOCK_ASSETS.loras] + mockApiResponse(loraAssets) + + result = await assetService.getAssetsForNodeType('LoraLoader') + expect(result).toEqual(loraAssets) + + // Test vae + mockGetCategoryForNodeType.mockReturnValue('vae') + const vaeAssets = [MOCK_ASSETS.vae] + mockApiResponse(vaeAssets) + + result = await assetService.getAssetsForNodeType('VAELoader') + expect(result).toEqual(vaeAssets) + }) + }) }) diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts index 179c76b9e..b07c34a41 100644 --- a/tests-ui/tests/store/modelToNodeStore.test.ts +++ b/tests-ui/tests/store/modelToNodeStore.test.ts @@ -19,7 +19,7 @@ const EXPECTED_DEFAULT_TYPES = [ 'gligen' ] as const -type NodeDefStoreType = typeof import('@/stores/nodeDefStore') +type NodeDefStoreType = ReturnType // Create minimal but valid ComfyNodeDefImpl for testing function createMockNodeDef(name: string): ComfyNodeDefImpl { @@ -343,6 +343,107 @@ describe('useModelToNodeStore', () => { }) }) + describe('getCategoryForNodeType', () => { + it('should return category for known node type', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + ).toBe('checkpoints') + expect(modelToNodeStore.getCategoryForNodeType('LoraLoader')).toBe( + 'loras' + ) + expect(modelToNodeStore.getCategoryForNodeType('VAELoader')).toBe('vae') + }) + + it('should return undefined for unknown node type', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('NonExistentNode') + ).toBeUndefined() + expect(modelToNodeStore.getCategoryForNodeType('')).toBeUndefined() + }) + + it('should return first category when node type exists in multiple categories', () => { + const modelToNodeStore = useModelToNodeStore() + + // Test with a node that exists in the defaults but add our own first + // Since defaults register 'StyleModelLoader' in 'style_models', + // we verify our custom registrations come after defaults in Object.entries iteration + const result = modelToNodeStore.getCategoryForNodeType('StyleModelLoader') + expect(result).toBe('style_models') // This proves the method works correctly + + // Now test that custom registrations after defaults also work + modelToNodeStore.quickRegister( + 'unicorn_styles', + 'StyleModelLoader', + 'param1' + ) + const result2 = + modelToNodeStore.getCategoryForNodeType('StyleModelLoader') + // Should still be style_models since it was registered first by defaults + expect(result2).toBe('style_models') + }) + + it('should trigger lazy registration when called before registerDefaults', () => { + const modelToNodeStore = useModelToNodeStore() + + const result = modelToNodeStore.getCategoryForNodeType( + 'CheckpointLoaderSimple' + ) + expect(result).toBe('checkpoints') + }) + + it('should be performant for repeated lookups', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + // Measure performance without assuming implementation + const start = performance.now() + for (let i = 0; i < 1000; i++) { + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + } + const end = performance.now() + + // Should be fast enough for UI responsiveness + expect(end - start).toBeLessThan(10) + }) + + it('should handle invalid input types gracefully', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + // These should not throw but return undefined + expect( + modelToNodeStore.getCategoryForNodeType(null as any) + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType(undefined as any) + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType(123 as any) + ).toBeUndefined() + }) + + it('should be case-sensitive for node type matching', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('checkpointloadersimple') + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType('CHECKPOINTLOADERSIMPLE') + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + ).toBe('checkpoints') + }) + }) + describe('edge cases', () => { it('should handle empty string model type', () => { const modelToNodeStore = useModelToNodeStore()