diff --git a/package.json b/package.json
index 770ef7e04..a1089933f 100644
--- a/package.json
+++ b/package.json
@@ -27,6 +27,8 @@
"preview": "nx preview",
"lint": "eslint src --cache",
"lint:fix": "eslint src --cache --fix",
+ "lint:unstaged": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache",
+ "lint:unstaged:fix": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache --fix",
"lint:no-cache": "eslint src",
"lint:fix:no-cache": "eslint src --fix",
"knip": "knip --cache",
diff --git a/src/i18n.ts b/src/i18n.ts
index 102ac2600..38a8dfe95 100644
--- a/src/i18n.ts
+++ b/src/i18n.ts
@@ -76,7 +76,7 @@ export const i18n = createI18n({
})
/** Convenience shorthand: i18n.global */
-export const { t, te } = i18n.global
+export const { t, te, d } = i18n.global
/**
* Safe translation function that returns the fallback message if the key is not found.
diff --git a/src/lib/litegraph/src/litegraph.ts b/src/lib/litegraph/src/litegraph.ts
index 46202a219..46b094af0 100644
--- a/src/lib/litegraph/src/litegraph.ts
+++ b/src/lib/litegraph/src/litegraph.ts
@@ -140,7 +140,7 @@ export { BaseWidget } from './widgets/BaseWidget'
export { LegacyWidget } from './widgets/LegacyWidget'
-export { isComboWidget } from './widgets/widgetMap'
+export { isComboWidget, isAssetWidget } from './widgets/widgetMap'
// Additional test-specific exports
export { LGraphButton } from './LGraphButton'
export { MovingOutputLink } from './canvas/MovingOutputLink'
diff --git a/src/lib/litegraph/src/widgets/AssetWidget.ts b/src/lib/litegraph/src/widgets/AssetWidget.ts
index f8a8e1209..1a5047beb 100644
--- a/src/lib/litegraph/src/widgets/AssetWidget.ts
+++ b/src/lib/litegraph/src/widgets/AssetWidget.ts
@@ -13,6 +13,22 @@ export class AssetWidget
this.value = widget.value?.toString() ?? ''
}
+ override set value(value: IAssetWidget['value']) {
+ const oldValue = this.value
+ super.value = value
+
+ // Force canvas redraw when value changes to show update immediately
+ if (oldValue !== value && this.node.graph?.list_of_graphcanvas) {
+ for (const canvas of this.node.graph.list_of_graphcanvas) {
+ canvas.setDirty(true)
+ }
+ }
+ }
+
+ override get value(): IAssetWidget['value'] {
+ return super.value
+ }
+
override get _displayValue(): string {
return String(this.value) //FIXME: Resolve asset name
}
diff --git a/src/lib/litegraph/src/widgets/widgetMap.ts b/src/lib/litegraph/src/widgets/widgetMap.ts
index 02cdb5597..0e6a34fe5 100644
--- a/src/lib/litegraph/src/widgets/widgetMap.ts
+++ b/src/lib/litegraph/src/widgets/widgetMap.ts
@@ -1,5 +1,6 @@
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
import type {
+ IAssetWidget,
IBaseWidget,
IComboWidget,
IWidget,
@@ -132,4 +133,9 @@ export function isComboWidget(widget: IBaseWidget): widget is IComboWidget {
return widget.type === 'combo'
}
+/** Type guard: Narrow **from {@link IBaseWidget}** to {@link IAssetWidget}. */
+export function isAssetWidget(widget: IBaseWidget): widget is IAssetWidget {
+ return widget.type === 'asset'
+}
+
// #endregion Type Guards
diff --git a/src/locales/en/main.json b/src/locales/en/main.json
index 3abd85335..74f0352ae 100644
--- a/src/locales/en/main.json
+++ b/src/locales/en/main.json
@@ -1873,6 +1873,13 @@
"noModelsInFolder": "No {type} available in this folder",
"searchAssetsPlaceholder": "Search assets...",
"allModels": "All Models",
- "unknown": "Unknown"
+ "unknown": "Unknown",
+ "fileFormats": "File formats",
+ "baseModels": "Base models",
+ "sortBy": "Sort by",
+ "sortAZ": "A-Z",
+ "sortZA": "Z-A",
+ "sortRecent": "Recent",
+ "sortPopular": "Popular"
}
}
diff --git a/src/platform/assets/components/AssetBrowserModal.stories.ts b/src/platform/assets/components/AssetBrowserModal.stories.ts
index acc93181d..9d2321d57 100644
--- a/src/platform/assets/components/AssetBrowserModal.stories.ts
+++ b/src/platform/assets/components/AssetBrowserModal.stories.ts
@@ -1,6 +1,7 @@
import type { Meta, StoryObj } from '@storybook/vue3-vite'
import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue'
+import type { AssetDisplayItem } from '@/platform/assets/composables/useAssetBrowser'
import {
createMockAssets,
mockAssets
@@ -56,7 +57,7 @@ export const Default: Story = {
render: (args) => ({
components: { AssetBrowserModal },
setup() {
- const onAssetSelect = (asset: any) => {
+ const onAssetSelect = (asset: AssetDisplayItem) => {
console.log('Selected asset:', asset)
}
const onClose = () => {
@@ -96,7 +97,7 @@ export const SingleAssetType: Story = {
render: (args) => ({
components: { AssetBrowserModal },
setup() {
- const onAssetSelect = (asset: any) => {
+ const onAssetSelect = (asset: AssetDisplayItem) => {
console.log('Selected asset:', asset)
}
const onClose = () => {
@@ -145,7 +146,7 @@ export const NoLeftPanel: Story = {
render: (args) => ({
components: { AssetBrowserModal },
setup() {
- const onAssetSelect = (asset: any) => {
+ const onAssetSelect = (asset: AssetDisplayItem) => {
console.log('Selected asset:', asset)
}
const onClose = () => {
diff --git a/src/platform/assets/components/AssetBrowserModal.vue b/src/platform/assets/components/AssetBrowserModal.vue
index de05f437d..cb45f38ba 100644
--- a/src/platform/assets/components/AssetBrowserModal.vue
+++ b/src/platform/assets/components/AssetBrowserModal.vue
@@ -12,7 +12,7 @@
:nav-items="availableCategories"
>
-
+
{{ $t('assetBrowser.browseAssets') }}
@@ -37,7 +37,7 @@
diff --git a/src/platform/assets/components/AssetCard.vue b/src/platform/assets/components/AssetCard.vue
index e379099c1..be7c45ca5 100644
--- a/src/platform/assets/components/AssetCard.vue
+++ b/src/platform/assets/components/AssetCard.vue
@@ -14,7 +14,7 @@
'bg-ivory-100 border border-gray-300 dark-theme:bg-charcoal-400 dark-theme:border-charcoal-600',
'hover:transform hover:-translate-y-0.5 hover:shadow-lg hover:shadow-black/10 hover:border-gray-400',
'dark-theme:hover:shadow-lg dark-theme:hover:shadow-black/30 dark-theme:hover:border-charcoal-700',
- 'focus:outline-none focus:ring-2 focus:ring-blue-500 dark-theme:focus:ring-blue-400'
+ 'focus:outline-none focus:transform focus:-translate-y-0.5 focus:shadow-lg focus:shadow-black/10 dark-theme:focus:shadow-black/30'
],
// Div-specific styles
!interactive && [
diff --git a/src/platform/assets/components/AssetFilterBar.vue b/src/platform/assets/components/AssetFilterBar.vue
index 1f3295b43..904ce3e82 100644
--- a/src/platform/assets/components/AssetFilterBar.vue
+++ b/src/platform/assets/components/AssetFilterBar.vue
@@ -3,7 +3,7 @@
void
+ ): Promise {
if (import.meta.env.DEV) {
- console.log('Asset selected:', asset.id, asset.name)
+ console.debug('Asset selected:', assetId)
+ }
+
+ if (!onSelect) {
+ return
+ }
+
+ try {
+ const detailAsset = await assetService.getAssetDetails(assetId)
+ const filename = detailAsset.user_metadata?.filename
+ const validatedFilename = assetFilenameSchema.safeParse(filename)
+ if (!validatedFilename.success) {
+ console.error(
+ 'Invalid asset filename:',
+ validatedFilename.error.errors,
+ 'for asset:',
+ assetId
+ )
+ return
+ }
+
+ onSelect(validatedFilename.data)
+ } catch (error) {
+ console.error(`Failed to fetch asset details for ${assetId}:`, error)
}
- return asset.id
}
return {
@@ -182,7 +212,6 @@ export function useAssetBrowser(assets: AssetItem[] = []) {
filteredAssets,
// Actions
- selectAsset,
- transformAssetForDisplay
+ selectAssetWithCallback
}
}
diff --git a/src/platform/assets/composables/useAssetBrowserDialog.ts b/src/platform/assets/composables/useAssetBrowserDialog.ts
index e5f63eead..31f75c353 100644
--- a/src/platform/assets/composables/useAssetBrowserDialog.ts
+++ b/src/platform/assets/composables/useAssetBrowserDialog.ts
@@ -1,5 +1,7 @@
import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue'
-import { useDialogStore } from '@/stores/dialogStore'
+import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
+import { assetService } from '@/platform/assets/services/assetService'
+import { type DialogComponentProps, useDialogStore } from '@/stores/dialogStore'
interface AssetBrowserDialogProps {
/** ComfyUI node type for context (e.g., 'CheckpointLoaderSimple') */
@@ -8,36 +10,29 @@ interface AssetBrowserDialogProps {
inputName: string
/** Current selected asset value */
currentValue?: string
- /** Callback for when an asset is selected */
- onAssetSelected?: (assetPath: string) => void
+ /**
+ * Callback for when an asset is selected
+ * @param {string} filename - The validated filename from user_metadata.filename
+ */
+ onAssetSelected?: (filename: string) => void
}
export const useAssetBrowserDialog = () => {
const dialogStore = useDialogStore()
const dialogKey = 'global-asset-browser'
- function hide() {
- dialogStore.closeDialog({ key: dialogKey })
- }
-
- function show(props: AssetBrowserDialogProps) {
- const handleAssetSelected = (assetPath: string) => {
- props.onAssetSelected?.(assetPath)
- hide() // Auto-close on selection
+ async function show(props: AssetBrowserDialogProps) {
+ const handleAssetSelected = (filename: string) => {
+ props.onAssetSelected?.(filename)
+ dialogStore.closeDialog({ key: dialogKey })
}
-
- const handleClose = () => {
- hide()
- }
-
- // Default dialog configuration for AssetBrowserModal
- const dialogComponentProps = {
+ const dialogComponentProps: DialogComponentProps = {
headless: true,
modal: true,
- closable: false,
+ closable: true,
pt: {
root: {
- class: 'rounded-2xl overflow-hidden'
+ class: 'rounded-2xl overflow-hidden asset-browser-dialog'
},
header: {
class: 'p-0 hidden'
@@ -48,6 +43,17 @@ export const useAssetBrowserDialog = () => {
}
}
+ const assets: AssetItem[] = await assetService
+ .getAssetsForNodeType(props.nodeType)
+ .catch((error) => {
+ console.error(
+ 'Failed to fetch assets for node type:',
+ props.nodeType,
+ error
+ )
+ return []
+ })
+
dialogStore.showDialog({
key: dialogKey,
component: AssetBrowserModal,
@@ -55,12 +61,13 @@ export const useAssetBrowserDialog = () => {
nodeType: props.nodeType,
inputName: props.inputName,
currentValue: props.currentValue,
+ assets,
onSelect: handleAssetSelected,
- onClose: handleClose
+ onClose: () => dialogStore.closeDialog({ key: dialogKey })
},
dialogComponentProps
})
}
- return { show, hide }
+ return { show }
}
diff --git a/src/platform/assets/schemas/assetSchema.ts b/src/platform/assets/schemas/assetSchema.ts
index fab41649a..2c051a30d 100644
--- a/src/platform/assets/schemas/assetSchema.ts
+++ b/src/platform/assets/schemas/assetSchema.ts
@@ -4,13 +4,13 @@ import { z } from 'zod'
const zAsset = z.object({
id: z.string(),
name: z.string(),
- asset_hash: z.string(),
+ asset_hash: z.string().nullable(),
size: z.number(),
- mime_type: z.string(),
+ mime_type: z.string().nullable(),
tags: z.array(z.string()),
preview_url: z.string().optional(),
created_at: z.string(),
- updated_at: z.string(),
+ updated_at: z.string().optional(),
last_access_time: z.string(),
user_metadata: z.record(z.unknown()).optional(), // API allows arbitrary key-value pairs
preview_id: z.string().nullable().optional()
@@ -33,6 +33,14 @@ const zModelFile = z.object({
pathIndex: z.number()
})
+// Filename validation schema
+export const assetFilenameSchema = z
+ .string()
+ .min(1, 'Filename cannot be empty')
+ .regex(/^[^\\:*?"<>|]+$/, 'Invalid filename characters') // Allow forward slashes, block backslashes and other unsafe chars
+ .regex(/^(?!\/|.*\.\.)/, 'Path must not start with / or contain ..') // Prevent absolute paths and directory traversal
+ .trim()
+
// Export schemas following repository patterns
export const assetResponseSchema = zAssetResponse
diff --git a/src/platform/assets/services/assetService.ts b/src/platform/assets/services/assetService.ts
index 74b20a753..7d0f82cbb 100644
--- a/src/platform/assets/services/assetService.ts
+++ b/src/platform/assets/services/assetService.ts
@@ -1,6 +1,7 @@
import { fromZodError } from 'zod-validation-error'
import {
+ type AssetItem,
type AssetResponse,
type ModelFile,
type ModelFolder,
@@ -127,10 +128,75 @@ function createAssetService() {
)
}
+ /**
+ * Gets assets for a specific node type by finding the matching category
+ * and fetching all assets with that category tag
+ *
+ * @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple')
+ * @returns Promise - Full asset objects with preserved metadata
+ */
+ async function getAssetsForNodeType(nodeType: string): Promise {
+ if (!nodeType || typeof nodeType !== 'string') {
+ return []
+ }
+
+ // Find the category for this node type using efficient O(1) lookup
+ const modelToNodeStore = useModelToNodeStore()
+ const category = modelToNodeStore.getCategoryForNodeType(nodeType)
+
+ if (!category) {
+ return []
+ }
+
+ // Fetch assets for this category using same API pattern as getAssetModels
+ const data = await handleAssetRequest(
+ `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}`,
+ `assets for ${nodeType}`
+ )
+
+ // Return full AssetItem[] objects (don't strip like getAssetModels does)
+ return (
+ data?.assets?.filter(
+ (asset) =>
+ !asset.tags.includes(MISSING_TAG) && asset.tags.includes(category)
+ ) ?? []
+ )
+ }
+
+ /**
+ * Gets complete details for a specific asset by ID
+ * Calls the detail endpoint which includes user_metadata and all fields
+ *
+ * @param id - The asset ID
+ * @returns Promise - Complete asset object with user_metadata
+ */
+ async function getAssetDetails(id: string): Promise {
+ const res = await api.fetchApi(`${ASSETS_ENDPOINT}/${id}`)
+ if (!res.ok) {
+ throw new Error(
+ `Unable to load asset details for ${id}: Server returned ${res.status}. Please try again.`
+ )
+ }
+ const data = await res.json()
+
+ // Validate the single asset response against our schema
+ const result = assetResponseSchema.safeParse({ assets: [data] })
+ if (result.success && result.data.assets?.[0]) {
+ return result.data.assets[0]
+ }
+
+ const error = result.error
+ ? fromZodError(result.error)
+ : 'Unknown validation error'
+ throw new Error(`Invalid asset response against zod schema:\n${error}`)
+ }
+
return {
getAssetModelFolders,
getAssetModels,
- isAssetBrowserEligible
+ isAssetBrowserEligible,
+ getAssetsForNodeType,
+ getAssetDetails
}
}
diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts
index 2387fc59c..59705458c 100644
--- a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts
+++ b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts
@@ -3,10 +3,9 @@ import { ref } from 'vue'
import MultiSelectWidget from '@/components/graph/widgets/MultiSelectWidget.vue'
import { t } from '@/i18n'
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
-import type {
- IBaseWidget,
- IComboWidget
-} from '@/lib/litegraph/src/types/widgets'
+import { isAssetWidget, isComboWidget } from '@/lib/litegraph/src/litegraph'
+import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
+import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
import { assetService } from '@/platform/assets/services/assetService'
import { useSettingStore } from '@/platform/settings/settingStore'
import { transformInputSpecV2ToV1 } from '@/schemas/nodeDef/migration'
@@ -73,11 +72,29 @@ const addComboWidget = (
const currentValue = getDefaultValue(inputSpec)
const displayLabel = currentValue ?? t('widgets.selectModel')
- const widget = node.addWidget('asset', inputSpec.name, displayLabel, () => {
- console.log(
- `Asset Browser would open here for:\nNode: ${node.type}\nWidget: ${inputSpec.name}\nCurrent Value:${currentValue}`
- )
- })
+ const assetBrowserDialog = useAssetBrowserDialog()
+
+ const widget = node.addWidget(
+ 'asset',
+ inputSpec.name,
+ displayLabel,
+ async () => {
+ if (!isAssetWidget(widget)) {
+ throw new Error(`Expected asset widget but received ${widget.type}`)
+ }
+ await assetBrowserDialog.show({
+ nodeType: node.comfyClass || '',
+ inputName: inputSpec.name,
+ currentValue: widget.value,
+ onAssetSelected: (filename: string) => {
+ const oldValue = widget.value
+ widget.value = filename
+ // Using onWidgetChanged prevents a callback race where asset selection could reopen the dialog
+ node.onWidgetChanged?.(widget.name, filename, oldValue, widget)
+ }
+ })
+ }
+ )
return widget
}
@@ -96,11 +113,14 @@ const addComboWidget = (
)
if (inputSpec.remote) {
+ if (!isComboWidget(widget)) {
+ throw new Error(`Expected combo widget but received ${widget.type}`)
+ }
const remoteWidget = useRemoteWidget({
remoteConfig: inputSpec.remote,
defaultValue,
node,
- widget: widget as IComboWidget
+ widget
})
if (inputSpec.remote.refresh_button) remoteWidget.addRefreshButton()
@@ -116,16 +136,19 @@ const addComboWidget = (
}
if (inputSpec.control_after_generate) {
+ if (!isComboWidget(widget)) {
+ throw new Error(`Expected combo widget but received ${widget.type}`)
+ }
widget.linkedWidgets = addValueControlWidgets(
node,
- widget as IComboWidget,
+ widget,
undefined,
undefined,
transformInputSpecV2ToV1(inputSpec)
)
}
- return widget as IBaseWidget
+ return widget
}
export const useComboWidget = () => {
diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts
index f6c15e91a..4f7925294 100644
--- a/src/stores/modelToNodeStore.ts
+++ b/src/stores/modelToNodeStore.ts
@@ -33,12 +33,43 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
)
})
+ /** Internal computed for efficient reverse lookup: nodeType -> category */
+ const nodeTypeToCategory = computed(() => {
+ const lookup: Record = {}
+ for (const [category, providers] of Object.entries(modelToNodeMap.value)) {
+ for (const provider of providers) {
+ // Only store the first category for each node type (matches current assetService behavior)
+ if (!lookup[provider.nodeDef.name]) {
+ lookup[provider.nodeDef.name] = category
+ }
+ }
+ }
+ return lookup
+ })
+
/** Get set of all registered node types for efficient lookup */
function getRegisteredNodeTypes(): Set {
registerDefaults()
return registeredNodeTypes.value
}
+ /**
+ * Get the category for a given node type.
+ * Performs efficient O(1) lookup using cached reverse map.
+ * @param nodeType The node type name to find the category for
+ * @returns The category name, or undefined if not found
+ */
+ function getCategoryForNodeType(nodeType: string): string | undefined {
+ registerDefaults()
+
+ // Handle invalid input gracefully
+ if (!nodeType || typeof nodeType !== 'string') {
+ return undefined
+ }
+
+ return nodeTypeToCategory.value[nodeType]
+ }
+
/**
* Get the node provider for the given model type name.
* @param modelType The name of the model type to get the node provider for.
@@ -109,6 +140,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
return {
modelToNodeMap,
getRegisteredNodeTypes,
+ getCategoryForNodeType,
getNodeProvider,
getAllNodeProviders,
registerNodeProvider,
diff --git a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts
index d7d4f74dc..bef33733b 100644
--- a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts
+++ b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts
@@ -1,10 +1,33 @@
-import { describe, expect, it } from 'vitest'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
import { nextTick } from 'vue'
import { useAssetBrowser } from '@/platform/assets/composables/useAssetBrowser'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
+import { assetService } from '@/platform/assets/services/assetService'
+
+vi.mock('@/platform/assets/services/assetService', () => ({
+ assetService: {
+ getAssetDetails: vi.fn()
+ }
+}))
+
+vi.mock('@/i18n', () => ({
+ t: (key: string) => {
+ const translations: Record = {
+ 'assetBrowser.allModels': 'All Models',
+ 'assetBrowser.assets': 'Assets',
+ 'assetBrowser.unknown': 'unknown'
+ }
+ return translations[key] || key
+ },
+ d: (date: Date) => date.toLocaleDateString()
+}))
describe('useAssetBrowser', () => {
+ beforeEach(() => {
+ vi.restoreAllMocks()
+ })
+
// Test fixtures - minimal data focused on functionality being tested
const createApiAsset = (overrides: Partial = {}): AssetItem => ({
id: 'test-id',
@@ -26,8 +49,8 @@ describe('useAssetBrowser', () => {
user_metadata: { description: 'Test model' }
})
- const { transformAssetForDisplay } = useAssetBrowser([apiAsset])
- const result = transformAssetForDisplay(apiAsset)
+ const { filteredAssets } = useAssetBrowser([apiAsset])
+ const result = filteredAssets.value[0] // Get the transformed asset from filteredAssets
// Preserves API properties
expect(result.id).toBe(apiAsset.id)
@@ -49,15 +72,13 @@ describe('useAssetBrowser', () => {
user_metadata: undefined
})
- const { transformAssetForDisplay } = useAssetBrowser([apiAsset])
- const result = transformAssetForDisplay(apiAsset)
+ const { filteredAssets } = useAssetBrowser([apiAsset])
+ const result = filteredAssets.value[0]
expect(result.description).toBe('loras model')
})
it('formats various file sizes correctly', () => {
- const { transformAssetForDisplay } = useAssetBrowser([])
-
const testCases = [
{ size: 512, expected: '512 B' },
{ size: 1536, expected: '1.5 KB' },
@@ -67,7 +88,8 @@ describe('useAssetBrowser', () => {
testCases.forEach(({ size, expected }) => {
const asset = createApiAsset({ size })
- const result = transformAssetForDisplay(asset)
+ const { filteredAssets } = useAssetBrowser([asset])
+ const result = filteredAssets.value[0]
expect(result.formattedSize).toBe(expected)
})
})
@@ -236,18 +258,182 @@ describe('useAssetBrowser', () => {
})
})
- describe('Asset Selection', () => {
- it('returns selected asset UUID for efficient handling', () => {
+ describe('Async Asset Selection with Detail Fetching', () => {
+ it('should fetch asset details and call onSelect with filename when provided', async () => {
+ const onSelectSpy = vi.fn()
const asset = createApiAsset({
- id: 'test-uuid-123',
- name: 'selected_model.safetensors'
+ id: 'asset-123',
+ name: 'test-model.safetensors'
})
- const { selectAsset, transformAssetForDisplay } = useAssetBrowser([asset])
- const displayAsset = transformAssetForDisplay(asset)
- const result = selectAsset(displayAsset)
+ const detailAsset = createApiAsset({
+ id: 'asset-123',
+ name: 'test-model.safetensors',
+ user_metadata: { filename: 'checkpoints/test-model.safetensors' }
+ })
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
- expect(result).toBe('test-uuid-123')
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id, onSelectSpy)
+
+ expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-123')
+ expect(onSelectSpy).toHaveBeenCalledWith(
+ 'checkpoints/test-model.safetensors'
+ )
+ })
+
+ it('should handle missing user_metadata.filename as error', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+ const asset = createApiAsset({ id: 'asset-456' })
+
+ const detailAsset = createApiAsset({
+ id: 'asset-456',
+ user_metadata: { filename: '' } // Invalid empty filename
+ })
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id, onSelectSpy)
+
+ expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-456')
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Invalid asset filename:',
+ expect.arrayContaining([
+ expect.objectContaining({
+ message: 'Filename cannot be empty'
+ })
+ ]),
+ 'for asset:',
+ 'asset-456'
+ )
+ })
+
+ it('should handle API errors gracefully', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+ const asset = createApiAsset({ id: 'asset-789' })
+
+ const apiError = new Error('API Error')
+ vi.mocked(assetService.getAssetDetails).mockRejectedValue(apiError)
+
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id, onSelectSpy)
+
+ expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-789')
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ expect.stringContaining('Failed to fetch asset details for asset-789'),
+ apiError
+ )
+ })
+
+ it('should not fetch details when no callback provided', async () => {
+ const asset = createApiAsset({ id: 'asset-no-callback' })
+
+ const { selectAssetWithCallback } = useAssetBrowser([asset])
+
+ await selectAssetWithCallback(asset.id)
+
+ expect(assetService.getAssetDetails).not.toHaveBeenCalled()
+ })
+ })
+
+ describe('Filename Validation Security', () => {
+ const createValidationTest = (filename: string) => {
+ const testAsset = createApiAsset({ id: 'validation-test' })
+ const detailAsset = createApiAsset({
+ id: 'validation-test',
+ user_metadata: { filename }
+ })
+ return { testAsset, detailAsset }
+ }
+
+ it('accepts valid file paths with forward slashes', async () => {
+ const onSelectSpy = vi.fn()
+ const { testAsset, detailAsset } = createValidationTest(
+ 'models/checkpoints/v1/test-model.safetensors'
+ )
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([testAsset])
+ await selectAssetWithCallback(testAsset.id, onSelectSpy)
+
+ expect(onSelectSpy).toHaveBeenCalledWith(
+ 'models/checkpoints/v1/test-model.safetensors'
+ )
+ })
+
+ it('rejects directory traversal attacks', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+
+ const maliciousPaths = [
+ '../malicious-model.safetensors',
+ 'models/../../../etc/passwd',
+ '/etc/passwd'
+ ]
+
+ for (const path of maliciousPaths) {
+ const { testAsset, detailAsset } = createValidationTest(path)
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([testAsset])
+ await selectAssetWithCallback(testAsset.id, onSelectSpy)
+
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Invalid asset filename:',
+ expect.arrayContaining([
+ expect.objectContaining({
+ message: 'Path must not start with / or contain ..'
+ })
+ ]),
+ 'for asset:',
+ 'validation-test'
+ )
+ }
+ })
+
+ it('rejects invalid filename characters', async () => {
+ const consoleErrorSpy = vi
+ .spyOn(console, 'error')
+ .mockImplementation(() => {})
+ const onSelectSpy = vi.fn()
+
+ const invalidChars = ['\\', ':', '*', '?', '"', '<', '>', '|']
+
+ for (const char of invalidChars) {
+ const { testAsset, detailAsset } = createValidationTest(
+ `bad${char}filename.safetensors`
+ )
+ vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
+
+ const { selectAssetWithCallback } = useAssetBrowser([testAsset])
+ await selectAssetWithCallback(testAsset.id, onSelectSpy)
+
+ expect(onSelectSpy).not.toHaveBeenCalled()
+ expect(consoleErrorSpy).toHaveBeenCalledWith(
+ 'Invalid asset filename:',
+ expect.arrayContaining([
+ expect.objectContaining({
+ message: 'Invalid filename characters'
+ })
+ ]),
+ 'for asset:',
+ 'validation-test'
+ )
+ }
})
})
diff --git a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts
index fefeeceac..102aa7a18 100644
--- a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts
+++ b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts
@@ -6,11 +6,18 @@ import { useDialogStore } from '@/stores/dialogStore'
// Mock the dialog store
vi.mock('@/stores/dialogStore')
+// Mock the asset service
+vi.mock('@/platform/assets/services/assetService', () => ({
+ assetService: {
+ getAssetsForNodeType: vi.fn().mockResolvedValue([])
+ }
+}))
+
// Test factory functions
interface AssetBrowserProps {
nodeType: string
inputName: string
- onAssetSelected?: ReturnType
+ onAssetSelected?: (filename: string) => void
}
function createAssetBrowserProps(
@@ -25,7 +32,7 @@ function createAssetBrowserProps(
describe('useAssetBrowserDialog', () => {
describe('Asset Selection Flow', () => {
- it('auto-closes dialog when asset is selected', () => {
+ it('auto-closes dialog when asset is selected', async () => {
// Create fresh mocks for this test
const mockShowDialog = vi.fn()
const mockCloseDialog = vi.fn()
@@ -41,7 +48,7 @@ describe('useAssetBrowserDialog', () => {
const onAssetSelected = vi.fn()
const props = createAssetBrowserProps({ onAssetSelected })
- assetBrowserDialog.show(props)
+ await assetBrowserDialog.show(props)
// Get the onSelect handler that was passed to the dialog
const dialogCall = mockShowDialog.mock.calls[0][0]
@@ -50,14 +57,14 @@ describe('useAssetBrowserDialog', () => {
// Simulate asset selection
onSelectHandler('selected-asset-path')
- // Should call the original callback and close dialog
+ // Should call the original callback and trigger hide animation
expect(onAssetSelected).toHaveBeenCalledWith('selected-asset-path')
expect(mockCloseDialog).toHaveBeenCalledWith({
key: 'global-asset-browser'
})
})
- it('closes dialog when close handler is called', () => {
+ it('closes dialog when close handler is called', async () => {
// Create fresh mocks for this test
const mockShowDialog = vi.fn()
const mockCloseDialog = vi.fn()
@@ -72,7 +79,7 @@ describe('useAssetBrowserDialog', () => {
const assetBrowserDialog = useAssetBrowserDialog()
const props = createAssetBrowserProps()
- assetBrowserDialog.show(props)
+ await assetBrowserDialog.show(props)
// Get the onClose handler that was passed to the dialog
const dialogCall = mockShowDialog.mock.calls[0][0]
diff --git a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts
index 875919ccd..d439d98ca 100644
--- a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts
+++ b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts
@@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
+import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
import { assetService } from '@/platform/assets/services/assetService'
import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget'
import type { InputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
@@ -29,13 +30,25 @@ vi.mock('@/platform/assets/services/assetService', () => ({
}
}))
+vi.mock('@/platform/assets/composables/useAssetBrowserDialog', () => {
+ const mockAssetBrowserDialogShow = vi.fn()
+ return {
+ useAssetBrowserDialog: vi.fn(() => ({
+ show: mockAssetBrowserDialogShow
+ }))
+ }
+})
+
// Test factory functions
function createMockWidget(overrides: Partial = {}): IBaseWidget {
+ const mockCallback = vi.fn()
return {
type: 'combo',
options: {},
name: 'testWidget',
value: undefined,
+ callback: mockCallback,
+ y: 0,
...overrides
} as IBaseWidget
}
@@ -45,7 +58,16 @@ function createMockNode(comfyClass = 'TestNode'): LGraphNode {
node.comfyClass = comfyClass
// Spy on the addWidget method
- vi.spyOn(node, 'addWidget').mockReturnValue(createMockWidget())
+ vi.spyOn(node, 'addWidget').mockImplementation(
+ (type, name, value, callback) => {
+ const widget = createMockWidget({ type, name, value })
+ // Store the callback function on the widget for testing
+ if (typeof callback === 'function') {
+ widget.callback = callback
+ }
+ return widget
+ }
+ )
return node
}
@@ -61,9 +83,9 @@ function createMockInputSpec(overrides: Partial = {}): InputSpec {
describe('useComboWidget', () => {
beforeEach(() => {
vi.clearAllMocks()
- // Reset to defaults
mockSettingStoreGet.mockReturnValue(false)
vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false)
+ vi.mocked(useAssetBrowserDialog).mockClear()
})
it('should handle undefined spec', () => {
diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts
index d96ef765b..7a719c4e4 100644
--- a/tests-ui/tests/services/assetService.test.ts
+++ b/tests-ui/tests/services/assetService.test.ts
@@ -4,6 +4,8 @@ import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { assetService } from '@/platform/assets/services/assetService'
import { api } from '@/scripts/api'
+const mockGetCategoryForNodeType = vi.fn()
+
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: vi.fn(() => ({
getRegisteredNodeTypes: vi.fn(
@@ -14,7 +16,13 @@ vi.mock('@/stores/modelToNodeStore', () => ({
'VAELoader',
'TestNode'
])
- )
+ ),
+ getCategoryForNodeType: mockGetCategoryForNodeType,
+ modelToNodeMap: {
+ checkpoints: [{ nodeDef: { name: 'CheckpointLoaderSimple' } }],
+ loras: [{ nodeDef: { name: 'LoraLoader' } }],
+ vae: [{ nodeDef: { name: 'VAELoader' } }]
+ }
}))
}))
@@ -210,4 +218,87 @@ describe('assetService', () => {
).toBe(false)
})
})
+
+ describe('getAssetsForNodeType', () => {
+ beforeEach(() => {
+ mockGetCategoryForNodeType.mockClear()
+ })
+
+ it('should return empty array for unregistered node types', async () => {
+ mockGetCategoryForNodeType.mockReturnValue(undefined)
+
+ const result = await assetService.getAssetsForNodeType('UnknownNode')
+
+ expect(mockGetCategoryForNodeType).toHaveBeenCalledWith('UnknownNode')
+ expect(result).toEqual([])
+ })
+
+ it('should use getCategoryForNodeType for efficient category lookup', async () => {
+ mockGetCategoryForNodeType.mockReturnValue('checkpoints')
+ const testAssets = [MOCK_ASSETS.checkpoints]
+ mockApiResponse(testAssets)
+
+ const result = await assetService.getAssetsForNodeType(
+ 'CheckpointLoaderSimple'
+ )
+
+ expect(mockGetCategoryForNodeType).toHaveBeenCalledWith(
+ 'CheckpointLoaderSimple'
+ )
+ expect(result).toEqual(testAssets)
+
+ // Verify API call includes correct category
+ expect(api.fetchApi).toHaveBeenCalledWith(
+ '/assets?include_tags=models,checkpoints'
+ )
+ })
+
+ it('should return empty array when no category found', async () => {
+ mockGetCategoryForNodeType.mockReturnValue(undefined)
+
+ const result = await assetService.getAssetsForNodeType('TestNode')
+
+ expect(result).toEqual([])
+ expect(api.fetchApi).not.toHaveBeenCalled()
+ })
+
+ it('should handle API errors gracefully', async () => {
+ mockGetCategoryForNodeType.mockReturnValue('loras')
+ mockApiError(500, 'Internal Server Error')
+
+ await expect(
+ assetService.getAssetsForNodeType('LoraLoader')
+ ).rejects.toThrow(
+ 'Unable to load assets for LoraLoader: Server returned 500. Please try again.'
+ )
+ })
+
+ it('should return all assets without filtering for different categories', async () => {
+ // Test checkpoints
+ mockGetCategoryForNodeType.mockReturnValue('checkpoints')
+ const checkpointAssets = [MOCK_ASSETS.checkpoints]
+ mockApiResponse(checkpointAssets)
+
+ let result = await assetService.getAssetsForNodeType(
+ 'CheckpointLoaderSimple'
+ )
+ expect(result).toEqual(checkpointAssets)
+
+ // Test loras
+ mockGetCategoryForNodeType.mockReturnValue('loras')
+ const loraAssets = [MOCK_ASSETS.loras]
+ mockApiResponse(loraAssets)
+
+ result = await assetService.getAssetsForNodeType('LoraLoader')
+ expect(result).toEqual(loraAssets)
+
+ // Test vae
+ mockGetCategoryForNodeType.mockReturnValue('vae')
+ const vaeAssets = [MOCK_ASSETS.vae]
+ mockApiResponse(vaeAssets)
+
+ result = await assetService.getAssetsForNodeType('VAELoader')
+ expect(result).toEqual(vaeAssets)
+ })
+ })
})
diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts
index 179c76b9e..b07c34a41 100644
--- a/tests-ui/tests/store/modelToNodeStore.test.ts
+++ b/tests-ui/tests/store/modelToNodeStore.test.ts
@@ -19,7 +19,7 @@ const EXPECTED_DEFAULT_TYPES = [
'gligen'
] as const
-type NodeDefStoreType = typeof import('@/stores/nodeDefStore')
+type NodeDefStoreType = ReturnType
// Create minimal but valid ComfyNodeDefImpl for testing
function createMockNodeDef(name: string): ComfyNodeDefImpl {
@@ -343,6 +343,107 @@ describe('useModelToNodeStore', () => {
})
})
+ describe('getCategoryForNodeType', () => {
+ it('should return category for known node type', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ expect(
+ modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple')
+ ).toBe('checkpoints')
+ expect(modelToNodeStore.getCategoryForNodeType('LoraLoader')).toBe(
+ 'loras'
+ )
+ expect(modelToNodeStore.getCategoryForNodeType('VAELoader')).toBe('vae')
+ })
+
+ it('should return undefined for unknown node type', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ expect(
+ modelToNodeStore.getCategoryForNodeType('NonExistentNode')
+ ).toBeUndefined()
+ expect(modelToNodeStore.getCategoryForNodeType('')).toBeUndefined()
+ })
+
+ it('should return first category when node type exists in multiple categories', () => {
+ const modelToNodeStore = useModelToNodeStore()
+
+ // Test with a node that exists in the defaults but add our own first
+ // Since defaults register 'StyleModelLoader' in 'style_models',
+ // we verify our custom registrations come after defaults in Object.entries iteration
+ const result = modelToNodeStore.getCategoryForNodeType('StyleModelLoader')
+ expect(result).toBe('style_models') // This proves the method works correctly
+
+ // Now test that custom registrations after defaults also work
+ modelToNodeStore.quickRegister(
+ 'unicorn_styles',
+ 'StyleModelLoader',
+ 'param1'
+ )
+ const result2 =
+ modelToNodeStore.getCategoryForNodeType('StyleModelLoader')
+ // Should still be style_models since it was registered first by defaults
+ expect(result2).toBe('style_models')
+ })
+
+ it('should trigger lazy registration when called before registerDefaults', () => {
+ const modelToNodeStore = useModelToNodeStore()
+
+ const result = modelToNodeStore.getCategoryForNodeType(
+ 'CheckpointLoaderSimple'
+ )
+ expect(result).toBe('checkpoints')
+ })
+
+ it('should be performant for repeated lookups', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ // Measure performance without assuming implementation
+ const start = performance.now()
+ for (let i = 0; i < 1000; i++) {
+ modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple')
+ }
+ const end = performance.now()
+
+ // Should be fast enough for UI responsiveness
+ expect(end - start).toBeLessThan(10)
+ })
+
+ it('should handle invalid input types gracefully', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ // These should not throw but return undefined
+ expect(
+ modelToNodeStore.getCategoryForNodeType(null as any)
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType(undefined as any)
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType(123 as any)
+ ).toBeUndefined()
+ })
+
+ it('should be case-sensitive for node type matching', () => {
+ const modelToNodeStore = useModelToNodeStore()
+ modelToNodeStore.registerDefaults()
+
+ expect(
+ modelToNodeStore.getCategoryForNodeType('checkpointloadersimple')
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType('CHECKPOINTLOADERSIMPLE')
+ ).toBeUndefined()
+ expect(
+ modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple')
+ ).toBe('checkpoints')
+ })
+ })
+
describe('edge cases', () => {
it('should handle empty string model type', () => {
const modelToNodeStore = useModelToNodeStore()