mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-09 01:20:09 +00:00
refactor: extract shared asset filter utilities
- Create assetFilterUtils.ts with reusable filter functions - Update useAssetBrowser.ts and WidgetSelectDropdown.vue to use shared utilities - Add comprehensive parameterized tests using it.for Amp-Thread-ID: https://ampcode.com/threads/T-019c18df-419b-7309-ae68-7f05682938d3 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -10,6 +10,12 @@ import type {
|
||||
OwnershipOption
|
||||
} from '@/platform/assets/types/filterTypes'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import {
|
||||
filterByBaseModels,
|
||||
filterByCategory,
|
||||
filterByFileFormats,
|
||||
filterByOwnership
|
||||
} from '@/platform/assets/utils/assetFilterUtils'
|
||||
import {
|
||||
getAssetBaseModels,
|
||||
getAssetFilename
|
||||
@@ -20,50 +26,6 @@ import type { NavGroupData, NavItemData } from '@/types/navTypes'
|
||||
|
||||
type NavId = 'all' | 'imported' | (string & {})
|
||||
|
||||
function filterByCategory(category: string) {
|
||||
return (asset: AssetItem) => {
|
||||
if (category === 'all') return true
|
||||
|
||||
// Check if any tag matches the category (for exact matches)
|
||||
if (asset.tags.includes(category)) return true
|
||||
|
||||
// Check if any tag's top-level folder matches the category
|
||||
return asset.tags.some((tag) => {
|
||||
if (typeof tag === 'string' && tag.includes('/')) {
|
||||
return tag.split('/')[0] === category
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function filterByFileFormats(formats: string[]) {
|
||||
return (asset: AssetItem) => {
|
||||
if (formats.length === 0) return true
|
||||
const formatSet = new Set(formats)
|
||||
const extension = asset.name.split('.').pop()?.toLowerCase()
|
||||
return extension ? formatSet.has(extension) : false
|
||||
}
|
||||
}
|
||||
|
||||
function filterByBaseModels(models: string[]) {
|
||||
return (asset: AssetItem) => {
|
||||
if (models.length === 0) return true
|
||||
const modelSet = new Set(models)
|
||||
const assetBaseModels = getAssetBaseModels(asset)
|
||||
return assetBaseModels.some((model) => modelSet.has(model))
|
||||
}
|
||||
}
|
||||
|
||||
function filterByOwnership(ownership: OwnershipOption) {
|
||||
return (asset: AssetItem) => {
|
||||
if (ownership === 'all') return true
|
||||
if (ownership === 'my-models') return asset.is_immutable === false
|
||||
if (ownership === 'public-models') return asset.is_immutable === true
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
type AssetBadge = {
|
||||
label: string
|
||||
type: 'type' | 'base' | 'size'
|
||||
|
||||
181
src/platform/assets/utils/assetFilterUtils.test.ts
Normal file
181
src/platform/assets/utils/assetFilterUtils.test.ts
Normal file
@@ -0,0 +1,181 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
|
||||
import {
|
||||
filterByBaseModels,
|
||||
filterByCategory,
|
||||
filterByFileFormats,
|
||||
filterByOwnership,
|
||||
filterItemByBaseModels,
|
||||
filterItemByOwnership
|
||||
} from './assetFilterUtils'
|
||||
|
||||
function createAsset(
|
||||
name: string,
|
||||
options: Partial<AssetItem> = {}
|
||||
): AssetItem {
|
||||
return {
|
||||
id: `asset-${name}`,
|
||||
name,
|
||||
tags: [],
|
||||
is_immutable: false,
|
||||
...options
|
||||
} as AssetItem
|
||||
}
|
||||
|
||||
describe('filterByCategory', () => {
|
||||
it.for([
|
||||
{ category: 'all', tags: ['checkpoint'], expected: true },
|
||||
{ category: 'checkpoint', tags: ['checkpoint'], expected: true },
|
||||
{ category: 'lora', tags: ['checkpoint'], expected: false },
|
||||
{
|
||||
category: 'checkpoint',
|
||||
tags: ['models', 'checkpoint/xl'],
|
||||
expected: true
|
||||
},
|
||||
{ category: 'xl', tags: ['models', 'checkpoint/xl'], expected: false }
|
||||
])(
|
||||
'category=$category with tags=$tags returns $expected',
|
||||
({ category, tags, expected }) => {
|
||||
const filter = filterByCategory(category)
|
||||
const asset = createAsset('model.safetensors', { tags })
|
||||
expect(filter(asset)).toBe(expected)
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
describe('filterByFileFormats', () => {
|
||||
it.for([
|
||||
{ formats: [], name: 'model.safetensors', expected: true },
|
||||
{ formats: ['safetensors'], name: 'model.safetensors', expected: true },
|
||||
{ formats: ['ckpt'], name: 'model.safetensors', expected: false },
|
||||
{ formats: ['safetensors'], name: 'MODEL.SAFETENSORS', expected: true },
|
||||
{ formats: ['safetensors'], name: 'model', expected: false }
|
||||
])(
|
||||
'formats=$formats with name=$name returns $expected',
|
||||
({ formats, name, expected }) => {
|
||||
const filter = filterByFileFormats(formats)
|
||||
const asset = createAsset(name)
|
||||
expect(filter(asset)).toBe(expected)
|
||||
}
|
||||
)
|
||||
|
||||
it('matches any of multiple formats', () => {
|
||||
const filter = filterByFileFormats(['safetensors', 'ckpt', 'bin'])
|
||||
expect(filter(createAsset('model.safetensors'))).toBe(true)
|
||||
expect(filter(createAsset('model.ckpt'))).toBe(true)
|
||||
expect(filter(createAsset('model.bin'))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('filterByBaseModels', () => {
|
||||
it.for([
|
||||
{ models: [], expected: true },
|
||||
{ models: new Set<string>(), expected: true }
|
||||
])('empty models ($models) returns true', ({ models }) => {
|
||||
const filter = filterByBaseModels(models)
|
||||
const asset = createAsset('model.safetensors')
|
||||
expect(filter(asset)).toBe(true)
|
||||
})
|
||||
|
||||
it.for([
|
||||
{
|
||||
models: ['SDXL'],
|
||||
metadata: { base_model: ['SDXL'] },
|
||||
expected: true
|
||||
},
|
||||
{
|
||||
models: ['SDXL'],
|
||||
metadata: { base_model: ['SD1.5'] },
|
||||
expected: false
|
||||
},
|
||||
{
|
||||
models: new Set(['SDXL', 'SD1.5']),
|
||||
metadata: { base_model: ['SDXL'] },
|
||||
expected: true
|
||||
}
|
||||
])(
|
||||
'models=$models with metadata.base_model returns $expected',
|
||||
({ models, metadata, expected }) => {
|
||||
const filter = filterByBaseModels(models)
|
||||
const asset = createAsset('model.safetensors', { metadata })
|
||||
expect(filter(asset)).toBe(expected)
|
||||
}
|
||||
)
|
||||
|
||||
it('matches base model in user_metadata', () => {
|
||||
const filter = filterByBaseModels(['SD1.5'])
|
||||
const asset = createAsset('model.safetensors', {
|
||||
user_metadata: { base_model: ['SD1.5'] }
|
||||
})
|
||||
expect(filter(asset)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('filterByOwnership', () => {
|
||||
it.for([
|
||||
{ ownership: 'all' as const, is_immutable: true, expected: true },
|
||||
{ ownership: 'all' as const, is_immutable: false, expected: true },
|
||||
{ ownership: 'my-models' as const, is_immutable: false, expected: true },
|
||||
{ ownership: 'my-models' as const, is_immutable: true, expected: false },
|
||||
{
|
||||
ownership: 'public-models' as const,
|
||||
is_immutable: true,
|
||||
expected: true
|
||||
},
|
||||
{
|
||||
ownership: 'public-models' as const,
|
||||
is_immutable: false,
|
||||
expected: false
|
||||
}
|
||||
])(
|
||||
'ownership=$ownership with is_immutable=$is_immutable returns $expected',
|
||||
({ ownership, is_immutable, expected }) => {
|
||||
const filter = filterByOwnership(ownership)
|
||||
const asset = createAsset('model', { is_immutable })
|
||||
expect(filter(asset)).toBe(expected)
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
describe('filterItemByOwnership', () => {
|
||||
const items = [
|
||||
{ id: '1', is_immutable: true },
|
||||
{ id: '2', is_immutable: false },
|
||||
{ id: '3', is_immutable: true }
|
||||
]
|
||||
|
||||
it.for([
|
||||
{ ownership: 'all' as const, expectedIds: ['1', '2', '3'] },
|
||||
{ ownership: 'my-models' as const, expectedIds: ['2'] },
|
||||
{ ownership: 'public-models' as const, expectedIds: ['1', '3'] }
|
||||
])(
|
||||
'ownership=$ownership returns items with ids=$expectedIds',
|
||||
({ ownership, expectedIds }) => {
|
||||
const result = filterItemByOwnership(items, ownership)
|
||||
expect(result.map((i) => i.id)).toEqual(expectedIds)
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
describe('filterItemByBaseModels', () => {
|
||||
const items = [
|
||||
{ id: '1', base_models: ['SDXL'] },
|
||||
{ id: '2', base_models: ['SD1.5'] },
|
||||
{ id: '3', base_models: ['SDXL', 'SD1.5'] },
|
||||
{ id: '4' }
|
||||
]
|
||||
|
||||
it.for([
|
||||
{ selectedModels: new Set<string>(), expectedIds: ['1', '2', '3', '4'] },
|
||||
{ selectedModels: new Set(['SDXL']), expectedIds: ['1', '3'] },
|
||||
{ selectedModels: new Set(['SD1.5']), expectedIds: ['2', '3'] }
|
||||
])(
|
||||
'selectedModels=$selectedModels returns items with ids=$expectedIds',
|
||||
({ selectedModels, expectedIds }) => {
|
||||
const result = filterItemByBaseModels(items, selectedModels)
|
||||
expect(result.map((i) => i.id)).toEqual(expectedIds)
|
||||
}
|
||||
)
|
||||
})
|
||||
66
src/platform/assets/utils/assetFilterUtils.ts
Normal file
66
src/platform/assets/utils/assetFilterUtils.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { OwnershipOption } from '@/platform/assets/types/filterTypes'
|
||||
import { getAssetBaseModels } from '@/platform/assets/utils/assetMetadataUtils'
|
||||
|
||||
export function filterByCategory(category: string) {
|
||||
return (asset: AssetItem) => {
|
||||
if (category === 'all') return true
|
||||
|
||||
// Check if any tag matches the category (for exact matches)
|
||||
if (asset.tags.includes(category)) return true
|
||||
|
||||
// Check if any tag's top-level folder matches the category
|
||||
return asset.tags.some((tag) => {
|
||||
if (typeof tag === 'string' && tag.includes('/')) {
|
||||
return tag.split('/')[0] === category
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export function filterByFileFormats(formats: string[]) {
|
||||
return (asset: AssetItem) => {
|
||||
if (formats.length === 0) return true
|
||||
const formatSet = new Set(formats)
|
||||
const extension = asset.name.split('.').pop()?.toLowerCase()
|
||||
return extension ? formatSet.has(extension) : false
|
||||
}
|
||||
}
|
||||
|
||||
export function filterByBaseModels(models: string[] | Set<string>) {
|
||||
const modelSet = models instanceof Set ? models : new Set(models)
|
||||
return (asset: AssetItem) => {
|
||||
if (modelSet.size === 0) return true
|
||||
const assetBaseModels = getAssetBaseModels(asset)
|
||||
return assetBaseModels.some((model) => modelSet.has(model))
|
||||
}
|
||||
}
|
||||
|
||||
export function filterByOwnership(ownership: OwnershipOption) {
|
||||
return (asset: AssetItem) => {
|
||||
if (ownership === 'all') return true
|
||||
if (ownership === 'my-models') return asset.is_immutable === false
|
||||
if (ownership === 'public-models') return asset.is_immutable === true
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
export function filterItemByOwnership<T extends { is_immutable?: boolean }>(
|
||||
items: T[],
|
||||
ownership: OwnershipOption
|
||||
): T[] {
|
||||
if (ownership === 'all') return items
|
||||
const isPublic = ownership === 'public-models'
|
||||
return items.filter((item) => item.is_immutable === isPublic)
|
||||
}
|
||||
|
||||
export function filterItemByBaseModels<T extends { base_models?: string[] }>(
|
||||
items: T[],
|
||||
selectedModels: Set<string>
|
||||
): T[] {
|
||||
if (selectedModels.size === 0) return items
|
||||
return items.filter((item) =>
|
||||
item.base_models?.some((model) => selectedModels.has(model))
|
||||
)
|
||||
}
|
||||
@@ -4,6 +4,10 @@ import { computed, provide, ref, toRef, watch } from 'vue'
|
||||
|
||||
import { useTransformCompatOverlayProps } from '@/composables/useTransformCompatOverlayProps'
|
||||
import { t } from '@/i18n'
|
||||
import {
|
||||
filterItemByBaseModels,
|
||||
filterItemByOwnership
|
||||
} from '@/platform/assets/utils/assetFilterUtils'
|
||||
import {
|
||||
getAssetBaseModels,
|
||||
getAssetDisplayName,
|
||||
@@ -242,25 +246,16 @@ const assetItems = computed<FormDropdownItem[]>(() => {
|
||||
}))
|
||||
})
|
||||
|
||||
/**
|
||||
* Filters asset items by ownership selection.
|
||||
*/
|
||||
const ownershipFilteredAssetItems = computed<FormDropdownItem[]>(() => {
|
||||
if (ownershipSelected.value === 'all') return assetItems.value
|
||||
const isPublic = ownershipSelected.value === 'public-models'
|
||||
return assetItems.value.filter((item) => item.is_immutable === isPublic)
|
||||
})
|
||||
const ownershipFilteredAssetItems = computed<FormDropdownItem[]>(() =>
|
||||
filterItemByOwnership(assetItems.value, ownershipSelected.value)
|
||||
)
|
||||
|
||||
/**
|
||||
* Filters asset items by base model selection.
|
||||
*/
|
||||
const baseModelFilteredAssetItems = computed<FormDropdownItem[]>(() => {
|
||||
if (baseModelSelected.value.size === 0)
|
||||
return ownershipFilteredAssetItems.value
|
||||
return ownershipFilteredAssetItems.value.filter((item) =>
|
||||
item.base_models?.some((model) => baseModelSelected.value.has(model))
|
||||
const baseModelFilteredAssetItems = computed<FormDropdownItem[]>(() =>
|
||||
filterItemByBaseModels(
|
||||
ownershipFilteredAssetItems.value,
|
||||
baseModelSelected.value
|
||||
)
|
||||
})
|
||||
)
|
||||
|
||||
const allItems = computed<FormDropdownItem[]>(() => {
|
||||
if (props.isAssetMode && assetData) {
|
||||
|
||||
Reference in New Issue
Block a user