mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-04-29 02:32:18 +00:00
Full Asset Selection Experience (Assets API) (#5900)
## Summary Full Integration of Asset Browsing and Selection when Assets API is enabled. ## Changes 1. Replace Model Left Side Tab with experience 2. Configurable titles for the Asset Browser Modal 3. Refactors to simplify callback code 4. Refactor to make modal filters reactive (they change their values based on assets displayed) 5. Add `browse()` mode with ability to create node directly from the Asset Browser Modal (in `browse()` mode) ## Screenshots Demo of many different types of Nodes getting configured by the Modal https://github.com/user-attachments/assets/34f9c964-cdf2-4c5d-86a9-a8e7126a7de9 ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-5900-Feat-asset-selection-cloud-integration-2816d73d365081ccb4aeecdc14b0e5d3) by [Unito](https://www.unito.io)
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
import { mount } from '@vue/test-utils'
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue'
|
||||
import type { AssetDisplayItem } from '@/platform/assets/composables/useAssetBrowser'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
|
||||
// Mock @/i18n for useAssetBrowser and AssetFilterBar
|
||||
@@ -57,6 +55,9 @@ vi.mock('@/components/widget/layout/BaseModalLayout.vue', () => ({
|
||||
<div data-testid="header">
|
||||
<slot name="header" />
|
||||
</div>
|
||||
<div v-if="$slots.contentFilter" data-testid="content-filter">
|
||||
<slot name="contentFilter" />
|
||||
</div>
|
||||
<div data-testid="content">
|
||||
<slot name="content" />
|
||||
</div>
|
||||
@@ -72,6 +73,9 @@ vi.mock('@/components/widget/panel/LeftSidePanel.vue', () => ({
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<div data-testid="left-side-panel">
|
||||
<div v-if="$slots['header-title']" data-testid="header-title">
|
||||
<slot name="header-title" />
|
||||
</div>
|
||||
<button
|
||||
v-for="item in navItems"
|
||||
:key="item.id"
|
||||
@@ -86,6 +90,19 @@ vi.mock('@/components/widget/panel/LeftSidePanel.vue', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/components/AssetFilterBar.vue', () => ({
|
||||
default: {
|
||||
name: 'AssetFilterBar',
|
||||
props: ['assets'],
|
||||
emits: ['filter-change'],
|
||||
template: `
|
||||
<div data-testid="asset-filter-bar">
|
||||
Filter bar with {{ assets?.length ?? 0 }} assets
|
||||
</div>
|
||||
`
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/components/AssetGrid.vue', () => ({
|
||||
default: {
|
||||
name: 'AssetGrid',
|
||||
@@ -169,99 +186,33 @@ describe('AssetBrowserModal', () => {
|
||||
})
|
||||
}
|
||||
|
||||
describe('Search Functionality', () => {
|
||||
it('filters assets when search query changes', async () => {
|
||||
describe('Integration with useAssetBrowser', () => {
|
||||
it('passes filteredAssets from composable to AssetGrid', () => {
|
||||
const assets = [
|
||||
createTestAsset('asset1', 'Checkpoint Model A', 'checkpoints'),
|
||||
createTestAsset('asset2', 'Checkpoint Model B', 'checkpoints'),
|
||||
createTestAsset('asset3', 'LoRA Model C', 'loras')
|
||||
createTestAsset('asset1', 'Model A', 'checkpoints'),
|
||||
createTestAsset('asset2', 'Model B', 'loras')
|
||||
]
|
||||
const wrapper = createWrapper(assets)
|
||||
|
||||
const searchBox = wrapper.find('[data-testid="search-box"]')
|
||||
|
||||
// Search for "Checkpoint"
|
||||
await searchBox.setValue('Checkpoint')
|
||||
await nextTick()
|
||||
|
||||
// Should filter to only checkpoint assets
|
||||
const assetGrid = wrapper.findComponent({ name: 'AssetGrid' })
|
||||
const filteredAssets = assetGrid.props('assets') as AssetDisplayItem[]
|
||||
|
||||
expect(filteredAssets.length).toBe(2)
|
||||
expect(
|
||||
filteredAssets.every((asset: AssetDisplayItem) =>
|
||||
asset.name.includes('Checkpoint')
|
||||
)
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('search is case insensitive', async () => {
|
||||
const assets = [
|
||||
createTestAsset('asset1', 'LoRA Model C', 'loras'),
|
||||
createTestAsset('asset2', 'Checkpoint Model', 'checkpoints')
|
||||
]
|
||||
const wrapper = createWrapper(assets)
|
||||
|
||||
const searchBox = wrapper.find('[data-testid="search-box"]')
|
||||
|
||||
// Search with different case
|
||||
await searchBox.setValue('lora')
|
||||
await nextTick()
|
||||
|
||||
const assetGrid = wrapper.findComponent({ name: 'AssetGrid' })
|
||||
const filteredAssets = assetGrid.props('assets') as AssetDisplayItem[]
|
||||
const gridAssets = assetGrid.props('assets')
|
||||
|
||||
expect(filteredAssets.length).toBe(1)
|
||||
expect(filteredAssets[0].name).toContain('LoRA')
|
||||
expect(gridAssets).toHaveLength(2)
|
||||
expect(gridAssets[0].id).toBe('asset1')
|
||||
})
|
||||
|
||||
it('shows empty state when search has no results', async () => {
|
||||
it('passes categoryFilteredAssets to AssetFilterBar', () => {
|
||||
const assets = [
|
||||
createTestAsset('asset1', 'Checkpoint Model', 'checkpoints')
|
||||
]
|
||||
const wrapper = createWrapper(assets)
|
||||
|
||||
const searchBox = wrapper.find('[data-testid="search-box"]')
|
||||
|
||||
// Search for something that doesn't exist
|
||||
await searchBox.setValue('nonexistent')
|
||||
await nextTick()
|
||||
|
||||
expect(wrapper.find('[data-testid="empty-state"]').exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Category Navigation', () => {
|
||||
it('filters assets by selected category', async () => {
|
||||
const assets = [
|
||||
createTestAsset('asset1', 'Checkpoint Model A', 'checkpoints'),
|
||||
createTestAsset('asset2', 'LoRA Model C', 'loras'),
|
||||
createTestAsset('asset3', 'VAE Model D', 'vae')
|
||||
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
|
||||
createTestAsset('l1', 'lora.pt', 'loras')
|
||||
]
|
||||
const wrapper = createWrapper(assets, { showLeftPanel: true })
|
||||
|
||||
// Wait for Vue reactivity and component mounting
|
||||
await nextTick()
|
||||
const filterBar = wrapper.findComponent({ name: 'AssetFilterBar' })
|
||||
const filterBarAssets = filterBar.props('assets')
|
||||
|
||||
// Check if left panel exists first (since we have multiple categories)
|
||||
const leftPanel = wrapper.find('[data-testid="left-panel"]')
|
||||
expect(leftPanel.exists()).toBe(true)
|
||||
|
||||
// Check if the nav item exists before clicking
|
||||
const lorasNavItem = wrapper.find('[data-testid="nav-item-loras"]')
|
||||
expect(lorasNavItem.exists()).toBe(true)
|
||||
|
||||
// Click the loras category
|
||||
await lorasNavItem.trigger('click')
|
||||
await nextTick()
|
||||
|
||||
// Should filter to only LoRA assets
|
||||
const assetGrid = wrapper.findComponent({ name: 'AssetGrid' })
|
||||
const filteredAssets = assetGrid.props('assets') as AssetDisplayItem[]
|
||||
|
||||
expect(filteredAssets.length).toBe(1)
|
||||
expect(filteredAssets[0].name).toContain('LoRA')
|
||||
// Should initially show all assets
|
||||
expect(filterBarAssets).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -277,7 +228,7 @@ describe('AssetBrowserModal', () => {
|
||||
expect(emitted).toBeDefined()
|
||||
expect(emitted).toHaveLength(1)
|
||||
|
||||
const emittedAsset = emitted![0][0] as AssetDisplayItem
|
||||
const emittedAsset = emitted![0][0] as AssetItem
|
||||
expect(emittedAsset.id).toBe('asset1')
|
||||
})
|
||||
|
||||
@@ -289,7 +240,12 @@ describe('AssetBrowserModal', () => {
|
||||
// Click on first asset
|
||||
await wrapper.find('[data-testid="asset-asset1"]').trigger('click')
|
||||
|
||||
expect(onSelectSpy).toHaveBeenCalledWith('Test Model')
|
||||
expect(onSelectSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
id: 'asset1',
|
||||
name: 'Test Model'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -327,4 +283,56 @@ describe('AssetBrowserModal', () => {
|
||||
expect(wrapper2.find('[data-testid="left-panel"]').exists()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Filter Options Reactivity', () => {
|
||||
it('updates filter options when category changes', async () => {
|
||||
const assets = [
|
||||
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
|
||||
createTestAsset('c2', 'another.safetensors', 'checkpoints'),
|
||||
createTestAsset('l1', 'lora.pt', 'loras')
|
||||
]
|
||||
const wrapper = createWrapper(assets, { showLeftPanel: true })
|
||||
|
||||
// Initially on "all" category - should have both .safetensors and .pt
|
||||
const filterBar = wrapper.findComponent({ name: 'AssetFilterBar' })
|
||||
expect(filterBar.exists()).toBe(true)
|
||||
|
||||
// Switch to checkpoints category
|
||||
const checkpointsNav = wrapper.find(
|
||||
'[data-testid="nav-item-checkpoints"]'
|
||||
)
|
||||
expect(checkpointsNav.exists()).toBe(true)
|
||||
await checkpointsNav.trigger('click')
|
||||
|
||||
// Filter bar should receive only checkpoint assets now
|
||||
const updatedFilterBar = wrapper.findComponent({ name: 'AssetFilterBar' })
|
||||
const filterBarAssets = updatedFilterBar.props('assets')
|
||||
|
||||
expect(filterBarAssets).toHaveLength(2)
|
||||
expect(
|
||||
filterBarAssets.every((a: AssetItem) => a.tags.includes('checkpoints'))
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Title Management', () => {
|
||||
it('passes custom title to BaseModalLayout when title prop provided', () => {
|
||||
const assets = [createTestAsset('asset1', 'Test Model', 'checkpoints')]
|
||||
const customTitle = 'Model Library'
|
||||
const wrapper = createWrapper(assets, { title: customTitle })
|
||||
|
||||
const baseModal = wrapper.findComponent({ name: 'BaseModalLayout' })
|
||||
expect(baseModal.props('contentTitle')).toBe(customTitle)
|
||||
})
|
||||
|
||||
it('passes computed contentTitle to BaseModalLayout when no title prop', () => {
|
||||
const assets = [createTestAsset('asset1', 'Test Model', 'checkpoints')]
|
||||
const wrapper = createWrapper(assets)
|
||||
|
||||
const baseModal = wrapper.findComponent({ name: 'BaseModalLayout' })
|
||||
// Should use contentTitle from useAssetBrowser (e.g., "All Models")
|
||||
expect(baseModal.props('contentTitle')).toBeTruthy()
|
||||
expect(baseModal.props('contentTitle')).not.toBe('')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,13 +3,6 @@ import { nextTick } from 'vue'
|
||||
|
||||
import { useAssetBrowser } from '@/platform/assets/composables/useAssetBrowser'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
getAssetDetails: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string) => {
|
||||
@@ -42,6 +35,38 @@ describe('useAssetBrowser', () => {
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('Category Filtering', () => {
|
||||
it('exposes category-filtered assets for filter options', () => {
|
||||
const checkpointAsset = createApiAsset({
|
||||
id: 'checkpoint-1',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'checkpoints']
|
||||
})
|
||||
const loraAsset = createApiAsset({
|
||||
id: 'lora-1',
|
||||
name: 'lora.pt',
|
||||
tags: ['models', 'loras']
|
||||
})
|
||||
|
||||
const { selectedCategory, categoryFilteredAssets } = useAssetBrowser([
|
||||
checkpointAsset,
|
||||
loraAsset
|
||||
])
|
||||
|
||||
// Initially should show all assets
|
||||
expect(categoryFilteredAssets.value).toHaveLength(2)
|
||||
|
||||
// When category selected, should only show that category
|
||||
selectedCategory.value = 'checkpoints'
|
||||
expect(categoryFilteredAssets.value).toHaveLength(1)
|
||||
expect(categoryFilteredAssets.value[0].id).toBe('checkpoint-1')
|
||||
|
||||
selectedCategory.value = 'loras'
|
||||
expect(categoryFilteredAssets.value).toHaveLength(1)
|
||||
expect(categoryFilteredAssets.value[0].id).toBe('lora-1')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Asset Transformation', () => {
|
||||
it('transforms API asset to include display properties', () => {
|
||||
const apiAsset = createApiAsset({
|
||||
@@ -258,185 +283,6 @@ describe('useAssetBrowser', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('Async Asset Selection with Detail Fetching', () => {
|
||||
it('should fetch asset details and call onSelect with filename when provided', async () => {
|
||||
const onSelectSpy = vi.fn()
|
||||
const asset = createApiAsset({
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors'
|
||||
})
|
||||
|
||||
const detailAsset = createApiAsset({
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
user_metadata: { filename: 'checkpoints/test-model.safetensors' }
|
||||
})
|
||||
vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
|
||||
|
||||
const { selectAssetWithCallback } = useAssetBrowser([asset])
|
||||
|
||||
await selectAssetWithCallback(asset.id, onSelectSpy)
|
||||
|
||||
expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-123')
|
||||
expect(onSelectSpy).toHaveBeenCalledWith(
|
||||
'checkpoints/test-model.safetensors'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle missing user_metadata.filename as error', async () => {
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
const onSelectSpy = vi.fn()
|
||||
const asset = createApiAsset({ id: 'asset-456' })
|
||||
|
||||
const detailAsset = createApiAsset({
|
||||
id: 'asset-456',
|
||||
user_metadata: { filename: '' } // Invalid empty filename
|
||||
})
|
||||
vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
|
||||
|
||||
const { selectAssetWithCallback } = useAssetBrowser([asset])
|
||||
|
||||
await selectAssetWithCallback(asset.id, onSelectSpy)
|
||||
|
||||
expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-456')
|
||||
expect(onSelectSpy).not.toHaveBeenCalled()
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Invalid asset filename:',
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
message: 'Filename cannot be empty'
|
||||
})
|
||||
]),
|
||||
'for asset:',
|
||||
'asset-456'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle API errors gracefully', async () => {
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
const onSelectSpy = vi.fn()
|
||||
const asset = createApiAsset({ id: 'asset-789' })
|
||||
|
||||
const apiError = new Error('API Error')
|
||||
vi.mocked(assetService.getAssetDetails).mockRejectedValue(apiError)
|
||||
|
||||
const { selectAssetWithCallback } = useAssetBrowser([asset])
|
||||
|
||||
await selectAssetWithCallback(asset.id, onSelectSpy)
|
||||
|
||||
expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-789')
|
||||
expect(onSelectSpy).not.toHaveBeenCalled()
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to fetch asset details for asset-789'),
|
||||
apiError
|
||||
)
|
||||
})
|
||||
|
||||
it('should not fetch details when no callback provided', async () => {
|
||||
const asset = createApiAsset({ id: 'asset-no-callback' })
|
||||
|
||||
const { selectAssetWithCallback } = useAssetBrowser([asset])
|
||||
|
||||
await selectAssetWithCallback(asset.id)
|
||||
|
||||
expect(assetService.getAssetDetails).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Filename Validation Security', () => {
|
||||
const createValidationTest = (filename: string) => {
|
||||
const testAsset = createApiAsset({ id: 'validation-test' })
|
||||
const detailAsset = createApiAsset({
|
||||
id: 'validation-test',
|
||||
user_metadata: { filename }
|
||||
})
|
||||
return { testAsset, detailAsset }
|
||||
}
|
||||
|
||||
it('accepts valid file paths with forward slashes', async () => {
|
||||
const onSelectSpy = vi.fn()
|
||||
const { testAsset, detailAsset } = createValidationTest(
|
||||
'models/checkpoints/v1/test-model.safetensors'
|
||||
)
|
||||
vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
|
||||
|
||||
const { selectAssetWithCallback } = useAssetBrowser([testAsset])
|
||||
await selectAssetWithCallback(testAsset.id, onSelectSpy)
|
||||
|
||||
expect(onSelectSpy).toHaveBeenCalledWith(
|
||||
'models/checkpoints/v1/test-model.safetensors'
|
||||
)
|
||||
})
|
||||
|
||||
it('rejects directory traversal attacks', async () => {
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
const onSelectSpy = vi.fn()
|
||||
|
||||
const maliciousPaths = [
|
||||
'../malicious-model.safetensors',
|
||||
'models/../../../etc/passwd',
|
||||
'/etc/passwd'
|
||||
]
|
||||
|
||||
for (const path of maliciousPaths) {
|
||||
const { testAsset, detailAsset } = createValidationTest(path)
|
||||
vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
|
||||
|
||||
const { selectAssetWithCallback } = useAssetBrowser([testAsset])
|
||||
await selectAssetWithCallback(testAsset.id, onSelectSpy)
|
||||
|
||||
expect(onSelectSpy).not.toHaveBeenCalled()
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Invalid asset filename:',
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
message: 'Path must not start with / or contain ..'
|
||||
})
|
||||
]),
|
||||
'for asset:',
|
||||
'validation-test'
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('rejects invalid filename characters', async () => {
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
const onSelectSpy = vi.fn()
|
||||
|
||||
const invalidChars = ['\\', ':', '*', '?', '"', '<', '>', '|']
|
||||
|
||||
for (const char of invalidChars) {
|
||||
const { testAsset, detailAsset } = createValidationTest(
|
||||
`bad${char}filename.safetensors`
|
||||
)
|
||||
vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset)
|
||||
|
||||
const { selectAssetWithCallback } = useAssetBrowser([testAsset])
|
||||
await selectAssetWithCallback(testAsset.id, onSelectSpy)
|
||||
|
||||
expect(onSelectSpy).not.toHaveBeenCalled()
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Invalid asset filename:',
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
message: 'Invalid filename characters'
|
||||
})
|
||||
]),
|
||||
'for asset:',
|
||||
'validation-test'
|
||||
)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Dynamic Category Extraction', () => {
|
||||
it('extracts categories from asset tags', () => {
|
||||
const assets = [
|
||||
|
||||
@@ -1,91 +1,100 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
|
||||
// Mock the dialog store
|
||||
vi.mock('@/stores/dialogStore')
|
||||
|
||||
// Mock the asset service
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
getAssetsForNodeType: vi.fn().mockResolvedValue([])
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string, params?: Record<string, string>) => {
|
||||
if (params) {
|
||||
return `${key}:${JSON.stringify(params)}`
|
||||
}
|
||||
return key
|
||||
}
|
||||
}))
|
||||
|
||||
// Test factory functions
|
||||
interface AssetBrowserProps {
|
||||
nodeType: string
|
||||
inputName: string
|
||||
onAssetSelected?: (filename: string) => void
|
||||
}
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
getAssetsForNodeType: vi.fn().mockResolvedValue([]),
|
||||
getAssetsByTag: vi.fn().mockResolvedValue([])
|
||||
}
|
||||
}))
|
||||
|
||||
function createAssetBrowserProps(
|
||||
overrides: Partial<AssetBrowserProps> = {}
|
||||
): AssetBrowserProps {
|
||||
const { assetService } = await import('@/platform/assets/services/assetService')
|
||||
const mockGetAssetsByTag = vi.mocked(assetService.getAssetsByTag)
|
||||
const mockGetAssetsForNodeType = vi.mocked(assetService.getAssetsForNodeType)
|
||||
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
inputName: 'ckpt_name',
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function setupDialogMocks() {
|
||||
const mockShowDialog = vi.fn()
|
||||
const mockCloseDialog = vi.fn()
|
||||
vi.mocked(useDialogStore, { partial: true }).mockReturnValue({
|
||||
showDialog: mockShowDialog,
|
||||
closeDialog: mockCloseDialog
|
||||
})
|
||||
|
||||
return { mockShowDialog, mockCloseDialog }
|
||||
}
|
||||
|
||||
describe('useAssetBrowserDialog', () => {
|
||||
describe('Asset Selection Flow', () => {
|
||||
it('auto-closes dialog when asset is selected', async () => {
|
||||
// Create fresh mocks for this test
|
||||
const mockShowDialog = vi.fn()
|
||||
const mockCloseDialog = vi.fn()
|
||||
|
||||
vi.mocked(useDialogStore).mockReturnValue({
|
||||
showDialog: mockShowDialog,
|
||||
closeDialog: mockCloseDialog
|
||||
} as Partial<ReturnType<typeof useDialogStore>> as ReturnType<
|
||||
typeof useDialogStore
|
||||
>)
|
||||
|
||||
const { mockShowDialog, mockCloseDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
const onAssetSelected = vi.fn()
|
||||
const props = createAssetBrowserProps({ onAssetSelected })
|
||||
|
||||
await assetBrowserDialog.show(props)
|
||||
await assetBrowserDialog.show({
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
inputName: 'ckpt_name',
|
||||
onAssetSelected
|
||||
})
|
||||
|
||||
// Get the onSelect handler that was passed to the dialog
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
const onSelectHandler = dialogCall.props.onSelect
|
||||
|
||||
// Simulate asset selection
|
||||
onSelectHandler('selected-asset-path')
|
||||
const mockAsset = {
|
||||
id: 'test-asset-id',
|
||||
name: 'test.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: { filename: 'selected-asset-path' }
|
||||
}
|
||||
onSelectHandler(mockAsset)
|
||||
|
||||
// Should call the original callback and trigger hide animation
|
||||
expect(onAssetSelected).toHaveBeenCalledWith('selected-asset-path')
|
||||
expect(onAssetSelected).toHaveBeenCalledWith(mockAsset)
|
||||
expect(mockCloseDialog).toHaveBeenCalledWith({
|
||||
key: 'global-asset-browser'
|
||||
})
|
||||
})
|
||||
|
||||
it('closes dialog when close handler is called', async () => {
|
||||
// Create fresh mocks for this test
|
||||
const mockShowDialog = vi.fn()
|
||||
const mockCloseDialog = vi.fn()
|
||||
|
||||
vi.mocked(useDialogStore).mockReturnValue({
|
||||
showDialog: mockShowDialog,
|
||||
closeDialog: mockCloseDialog
|
||||
} as Partial<ReturnType<typeof useDialogStore>> as ReturnType<
|
||||
typeof useDialogStore
|
||||
>)
|
||||
|
||||
const { mockShowDialog, mockCloseDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
const props = createAssetBrowserProps()
|
||||
|
||||
await assetBrowserDialog.show(props)
|
||||
await assetBrowserDialog.show({
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
inputName: 'ckpt_name'
|
||||
})
|
||||
|
||||
// Get the onClose handler that was passed to the dialog
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
const onCloseHandler = dialogCall.props.onClose
|
||||
|
||||
// Simulate dialog close
|
||||
onCloseHandler()
|
||||
|
||||
expect(mockCloseDialog).toHaveBeenCalledWith({
|
||||
@@ -93,4 +102,158 @@ describe('useAssetBrowserDialog', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('.browse() method', () => {
|
||||
it('opens asset browser dialog with tag-based filtering', async () => {
|
||||
const { mockShowDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
await assetBrowserDialog.browse({
|
||||
assetType: 'models',
|
||||
title: 'Model Library'
|
||||
})
|
||||
|
||||
expect(mockShowDialog).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
key: 'global-asset-browser',
|
||||
props: expect.objectContaining({
|
||||
showLeftPanel: true
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('calls onAssetSelected callback when asset is selected', async () => {
|
||||
const { mockShowDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
const mockAsset = createMockAsset()
|
||||
const onAssetSelected = vi.fn()
|
||||
await assetBrowserDialog.browse({
|
||||
assetType: 'models',
|
||||
onAssetSelected
|
||||
})
|
||||
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
const onSelectHandler = dialogCall.props.onSelect
|
||||
|
||||
onSelectHandler(mockAsset)
|
||||
|
||||
expect(onAssetSelected).toHaveBeenCalledWith(mockAsset)
|
||||
})
|
||||
|
||||
it('closes dialog after asset selection', async () => {
|
||||
const { mockShowDialog, mockCloseDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
const mockAsset = createMockAsset()
|
||||
await assetBrowserDialog.browse({
|
||||
assetType: 'models'
|
||||
})
|
||||
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
const onSelectHandler = dialogCall.props.onSelect
|
||||
|
||||
onSelectHandler(mockAsset)
|
||||
|
||||
expect(mockCloseDialog).toHaveBeenCalledWith({
|
||||
key: 'global-asset-browser'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses custom title when provided', async () => {
|
||||
const { mockShowDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
await assetBrowserDialog.browse({
|
||||
assetType: 'models',
|
||||
title: 'Custom Model Browser'
|
||||
})
|
||||
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
expect(dialogCall.props.title).toBe('Custom Model Browser')
|
||||
})
|
||||
|
||||
it('calls getAssetsByTag with correct assetType parameter', async () => {
|
||||
setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
await assetBrowserDialog.browse({
|
||||
assetType: 'models'
|
||||
})
|
||||
|
||||
expect(mockGetAssetsByTag).toHaveBeenCalledWith('models')
|
||||
})
|
||||
|
||||
it('passes fetched assets to dialog props', async () => {
|
||||
const { mockShowDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
const mockAssets = [
|
||||
createMockAsset({ id: 'asset-1', name: 'model1.safetensors' }),
|
||||
createMockAsset({ id: 'asset-2', name: 'model2.safetensors' })
|
||||
]
|
||||
|
||||
mockGetAssetsByTag.mockResolvedValueOnce(mockAssets)
|
||||
await assetBrowserDialog.browse({
|
||||
assetType: 'models'
|
||||
})
|
||||
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
expect(dialogCall.props.assets).toEqual(mockAssets)
|
||||
})
|
||||
|
||||
it('handles asset fetch errors gracefully', async () => {
|
||||
const { mockShowDialog } = setupDialogMocks()
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
mockGetAssetsByTag.mockRejectedValueOnce(new Error('Network error'))
|
||||
await assetBrowserDialog.browse({
|
||||
assetType: 'models'
|
||||
})
|
||||
|
||||
expect(mockShowDialog).toHaveBeenCalled()
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
expect(dialogCall.props.assets).toEqual([])
|
||||
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Failed to fetch assets for tag:',
|
||||
'models',
|
||||
expect.any(Error)
|
||||
)
|
||||
|
||||
consoleErrorSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('.show() title formatting', () => {
|
||||
it('formats title with VAE acronym uppercase', async () => {
|
||||
const { mockShowDialog } = setupDialogMocks()
|
||||
mockGetAssetsForNodeType.mockResolvedValueOnce([
|
||||
createMockAsset({ tags: ['models', 'vae'] })
|
||||
])
|
||||
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
await assetBrowserDialog.show({
|
||||
nodeType: 'VAELoader',
|
||||
inputName: 'vae_name'
|
||||
})
|
||||
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
expect(dialogCall.props.title).toContain('VAE')
|
||||
})
|
||||
|
||||
it('replaces underscores with spaces in tag names', async () => {
|
||||
const { mockShowDialog } = setupDialogMocks()
|
||||
mockGetAssetsForNodeType.mockResolvedValueOnce([
|
||||
createMockAsset({ tags: ['models', 'style_models'] })
|
||||
])
|
||||
|
||||
const assetBrowserDialog = useAssetBrowserDialog()
|
||||
await assetBrowserDialog.show({
|
||||
nodeType: 'StyleModelLoader',
|
||||
inputName: 'style_model_name'
|
||||
})
|
||||
|
||||
const dialogCall = mockShowDialog.mock.calls[0][0]
|
||||
expect(dialogCall.props.title).toContain('style models')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -18,7 +18,7 @@ describe('useAssetFilterOptions', () => {
|
||||
createAssetWithSpecificExtension('pt')
|
||||
]
|
||||
|
||||
const { availableFileFormats } = useAssetFilterOptions(assets)
|
||||
const { availableFileFormats } = useAssetFilterOptions(() => assets)
|
||||
|
||||
expect(availableFileFormats.value).toEqual([
|
||||
{ name: '.ckpt', value: 'ckpt' },
|
||||
@@ -34,7 +34,7 @@ describe('useAssetFilterOptions', () => {
|
||||
createAssetWithSpecificExtension('ckpt')
|
||||
]
|
||||
|
||||
const { availableFileFormats } = useAssetFilterOptions(assets)
|
||||
const { availableFileFormats } = useAssetFilterOptions(() => assets)
|
||||
|
||||
expect(availableFileFormats.value).toEqual([
|
||||
{ name: '.ckpt', value: 'ckpt' },
|
||||
@@ -48,7 +48,7 @@ describe('useAssetFilterOptions', () => {
|
||||
createAssetWithSpecificExtension('safetensors')
|
||||
]
|
||||
|
||||
const { availableFileFormats } = useAssetFilterOptions(assets)
|
||||
const { availableFileFormats } = useAssetFilterOptions(() => assets)
|
||||
|
||||
expect(availableFileFormats.value).toEqual([
|
||||
{ name: '.safetensors', value: 'safetensors' }
|
||||
@@ -56,7 +56,7 @@ describe('useAssetFilterOptions', () => {
|
||||
})
|
||||
|
||||
it('handles empty asset list', () => {
|
||||
const { availableFileFormats } = useAssetFilterOptions([])
|
||||
const { availableFileFormats } = useAssetFilterOptions(() => [])
|
||||
|
||||
expect(availableFileFormats.value).toEqual([])
|
||||
})
|
||||
@@ -70,7 +70,7 @@ describe('useAssetFilterOptions', () => {
|
||||
createAssetWithSpecificBaseModel('sd35')
|
||||
]
|
||||
|
||||
const { availableBaseModels } = useAssetFilterOptions(assets)
|
||||
const { availableBaseModels } = useAssetFilterOptions(() => assets)
|
||||
|
||||
expect(availableBaseModels.value).toEqual([
|
||||
{ name: 'sd15', value: 'sd15' },
|
||||
@@ -86,7 +86,7 @@ describe('useAssetFilterOptions', () => {
|
||||
createAssetWithSpecificBaseModel('sdxl')
|
||||
]
|
||||
|
||||
const { availableBaseModels } = useAssetFilterOptions(assets)
|
||||
const { availableBaseModels } = useAssetFilterOptions(() => assets)
|
||||
|
||||
expect(availableBaseModels.value).toEqual([
|
||||
{ name: 'sd15', value: 'sd15' },
|
||||
@@ -100,7 +100,7 @@ describe('useAssetFilterOptions', () => {
|
||||
createAssetWithSpecificBaseModel('sd15')
|
||||
]
|
||||
|
||||
const { availableBaseModels } = useAssetFilterOptions(assets)
|
||||
const { availableBaseModels } = useAssetFilterOptions(() => assets)
|
||||
|
||||
expect(availableBaseModels.value).toEqual([
|
||||
{ name: 'sd15', value: 'sd15' }
|
||||
@@ -113,7 +113,7 @@ describe('useAssetFilterOptions', () => {
|
||||
createAssetWithSpecificBaseModel('sdxl')
|
||||
]
|
||||
|
||||
const { availableBaseModels } = useAssetFilterOptions(assets)
|
||||
const { availableBaseModels } = useAssetFilterOptions(() => assets)
|
||||
|
||||
expect(availableBaseModels.value).toEqual([
|
||||
{ name: 'sdxl', value: 'sdxl' }
|
||||
@@ -121,7 +121,7 @@ describe('useAssetFilterOptions', () => {
|
||||
})
|
||||
|
||||
it('handles empty asset list', () => {
|
||||
const { availableBaseModels } = useAssetFilterOptions([])
|
||||
const { availableBaseModels } = useAssetFilterOptions(() => [])
|
||||
|
||||
expect(availableBaseModels.value).toEqual([])
|
||||
})
|
||||
@@ -132,9 +132,8 @@ describe('useAssetFilterOptions', () => {
|
||||
const assets = [createAssetWithSpecificExtension('safetensors')]
|
||||
|
||||
const { availableFileFormats, availableBaseModels } =
|
||||
useAssetFilterOptions(assets)
|
||||
useAssetFilterOptions(() => assets)
|
||||
|
||||
// These should be computed refs
|
||||
expect(availableFileFormats.value).toBeDefined()
|
||||
expect(availableBaseModels.value).toBeDefined()
|
||||
expect(typeof availableFileFormats.value).toBe('object')
|
||||
|
||||
403
tests-ui/platform/assets/utils/createModelNodeFromAsset.test.ts
Normal file
403
tests-ui/platform/assets/utils/createModelNodeFromAsset.test.ts
Normal file
@@ -0,0 +1,403 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { type Raw, markRaw } from 'vue'
|
||||
|
||||
import {
|
||||
type LGraphNode,
|
||||
LiteGraph,
|
||||
type Subgraph
|
||||
} from '@/lib/litegraph/src/litegraph'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { createModelNodeFromAsset } from '@/platform/assets/utils/createModelNodeFromAsset'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
apiURL: vi.fn((path: string) => `http://localhost:8188${path}`)
|
||||
}
|
||||
}))
|
||||
vi.mock('@/stores/modelToNodeStore', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@/stores/modelToNodeStore')>()
|
||||
return {
|
||||
...actual,
|
||||
useModelToNodeStore: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock(
|
||||
'@/platform/workflow/management/stores/workflowStore',
|
||||
async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<
|
||||
typeof import('@/platform/workflow/management/stores/workflowStore')
|
||||
>()
|
||||
return {
|
||||
...actual,
|
||||
useWorkflowStore: vi.fn()
|
||||
}
|
||||
}
|
||||
)
|
||||
vi.mock('@/services/litegraphService', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@/services/litegraphService')>()
|
||||
return {
|
||||
...actual,
|
||||
useLitegraphService: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock('@/lib/litegraph/src/litegraph', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@/lib/litegraph/src/litegraph')>()
|
||||
return {
|
||||
...actual,
|
||||
LiteGraph: {
|
||||
...actual.LiteGraph,
|
||||
createNode: vi.fn()
|
||||
}
|
||||
}
|
||||
})
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
canvas: {
|
||||
graph: {
|
||||
add: vi.fn()
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
async function createMockNode(overrides?: {
|
||||
widgetName?: string
|
||||
widgetValue?: string
|
||||
hasWidgets?: boolean
|
||||
}): Promise<LGraphNode> {
|
||||
const {
|
||||
widgetName = 'ckpt_name',
|
||||
widgetValue = '',
|
||||
hasWidgets = true
|
||||
} = overrides || {}
|
||||
|
||||
const { LGraphNode: ActualLGraphNode } = await vi.importActual<
|
||||
typeof import('@/lib/litegraph/src/litegraph')
|
||||
>('@/lib/litegraph/src/litegraph')
|
||||
|
||||
if (!hasWidgets) {
|
||||
return Object.create(ActualLGraphNode.prototype)
|
||||
}
|
||||
|
||||
type Widget = NonNullable<LGraphNode['widgets']>[number]
|
||||
const widget: Pick<Widget, 'name' | 'value' | 'type' | 'options' | 'y'> = {
|
||||
name: widgetName,
|
||||
value: widgetValue,
|
||||
type: 'string',
|
||||
options: {},
|
||||
y: 0
|
||||
}
|
||||
|
||||
return Object.create(ActualLGraphNode.prototype, {
|
||||
widgets: { value: [widget], writable: true }
|
||||
})
|
||||
}
|
||||
function createMockNodeProvider() {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint'
|
||||
},
|
||||
key: 'ckpt_name'
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Configures all mocked dependencies with sensible defaults.
|
||||
* Uses semantic parameters for clearer test intent.
|
||||
* For error paths or edge cases, pass null values or specific overrides.
|
||||
*/
|
||||
async function setupMocks(
|
||||
overrides: {
|
||||
nodeProvider?: ReturnType<typeof createMockNodeProvider> | null
|
||||
canvasCenter?: [number, number]
|
||||
activeSubgraph?: Raw<Subgraph>
|
||||
createdNode?: Awaited<ReturnType<typeof createMockNode>> | null
|
||||
} = {}
|
||||
) {
|
||||
const {
|
||||
nodeProvider = createMockNodeProvider(),
|
||||
canvasCenter = [100, 200],
|
||||
activeSubgraph = undefined,
|
||||
createdNode = await createMockNode()
|
||||
} = overrides
|
||||
|
||||
vi.mocked(useModelToNodeStore).mockReturnValue({
|
||||
...useModelToNodeStore(),
|
||||
getNodeProvider: vi.fn().mockReturnValue(nodeProvider)
|
||||
})
|
||||
|
||||
vi.mocked(useLitegraphService).mockReturnValue({
|
||||
...useLitegraphService(),
|
||||
getCanvasCenter: vi.fn().mockReturnValue(canvasCenter)
|
||||
})
|
||||
|
||||
vi.mocked(useWorkflowStore).mockReturnValue({
|
||||
...useWorkflowStore(),
|
||||
activeSubgraph,
|
||||
isSubgraphActive: !!activeSubgraph
|
||||
})
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(createdNode)
|
||||
}
|
||||
describe('createModelNodeFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
})
|
||||
describe('when creating nodes from valid assets', () => {
|
||||
it('should create the appropriate loader node for the asset category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(
|
||||
vi.mocked(useModelToNodeStore)().getNodeProvider
|
||||
).toHaveBeenCalledWith('checkpoints')
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [100, 200] }
|
||||
)
|
||||
}
|
||||
})
|
||||
it('should place node at canvas center by default', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({
|
||||
canvasCenter: [150, 250]
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [150, 250] }
|
||||
)
|
||||
})
|
||||
it('should place node at specified position when position is provided', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset, { position: [300, 400] })
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).not.toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [300, 400] }
|
||||
)
|
||||
})
|
||||
it('should populate the loader widget with the asset file path', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe(
|
||||
'models/checkpoints/test-model.safetensors'
|
||||
)
|
||||
})
|
||||
it('should add node to root graph when no subgraph is active', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
it('should add node to active subgraph when present', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
const { Subgraph } = await vi.importActual<
|
||||
typeof import('@/lib/litegraph/src/litegraph')
|
||||
>('@/lib/litegraph/src/litegraph')
|
||||
const mockSubgraph = markRaw(
|
||||
Object.create(Subgraph.prototype, {
|
||||
add: { value: vi.fn() }
|
||||
})
|
||||
)
|
||||
await setupMocks({
|
||||
createdNode: mockNode,
|
||||
activeSubgraph: mockSubgraph
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockSubgraph.add).toHaveBeenCalledWith(mockNode)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
describe('when asset data is incomplete or invalid', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it.each([
|
||||
{
|
||||
case: 'missing user_metadata',
|
||||
overrides: { user_metadata: undefined },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /missing required user_metadata/
|
||||
},
|
||||
{
|
||||
case: 'missing filename property',
|
||||
overrides: { user_metadata: {} },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern:
|
||||
/Invalid filename.*expected non-empty string, got undefined/
|
||||
},
|
||||
{
|
||||
case: 'non-string filename',
|
||||
overrides: { user_metadata: { filename: 123 } },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string, got number/
|
||||
},
|
||||
{
|
||||
case: 'empty filename',
|
||||
overrides: { user_metadata: { filename: '' } },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorPattern }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toMatch(errorPattern)
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
}
|
||||
}
|
||||
)
|
||||
it.each([
|
||||
{
|
||||
case: 'no tags',
|
||||
overrides: { tags: undefined },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no tags defined'
|
||||
},
|
||||
{
|
||||
case: 'only excluded tags',
|
||||
overrides: { tags: ['models', 'missing'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
},
|
||||
{
|
||||
case: 'only the models tag',
|
||||
overrides: { tags: ['models'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorMessage }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toBe(errorMessage)
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
describe('when system resources are unavailable', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it('should fail when no provider registered for category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({ nodeProvider: null })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_PROVIDER')
|
||||
expect(result.error.message).toContain('checkpoints')
|
||||
expect(result.error.details?.category).toBe('checkpoints')
|
||||
}
|
||||
})
|
||||
it('should fail when node creation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(null)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NODE_CREATION_FAILED')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
}
|
||||
})
|
||||
it('should fail when widget is missing from node', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ widgetName: 'wrong_widget' })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
expect(result.error.details?.widgetName).toBe('ckpt_name')
|
||||
}
|
||||
})
|
||||
it('should fail when node has no widgets array', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name not found')
|
||||
}
|
||||
})
|
||||
it('should not add node to graph when widget validation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
createModelNodeFromAsset(asset)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
describe('when graph is null', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.mocked(app).canvas.graph = null
|
||||
})
|
||||
it('should fail when no graph is available', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_GRAPH')
|
||||
expect(result.error.message).toBe('No active graph available')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user