refactor: extract shared asset filter utilities

- Create assetFilterUtils.ts with reusable filter functions

- Update useAssetBrowser.ts and WidgetSelectDropdown.vue to use shared utilities

- Add comprehensive parameterized tests using it.for

Amp-Thread-ID: https://ampcode.com/threads/T-019c18df-419b-7309-ae68-7f05682938d3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Alexander Brown
2026-02-01 12:27:44 -08:00
parent aa1aef17de
commit 091e67590b
4 changed files with 265 additions and 61 deletions

View File

@@ -10,6 +10,12 @@ import type {
OwnershipOption
} from '@/platform/assets/types/filterTypes'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import {
filterByBaseModels,
filterByCategory,
filterByFileFormats,
filterByOwnership
} from '@/platform/assets/utils/assetFilterUtils'
import {
getAssetBaseModels,
getAssetFilename
@@ -20,50 +26,6 @@ import type { NavGroupData, NavItemData } from '@/types/navTypes'
type NavId = 'all' | 'imported' | (string & {})
function filterByCategory(category: string) {
return (asset: AssetItem) => {
if (category === 'all') return true
// Check if any tag matches the category (for exact matches)
if (asset.tags.includes(category)) return true
// Check if any tag's top-level folder matches the category
return asset.tags.some((tag) => {
if (typeof tag === 'string' && tag.includes('/')) {
return tag.split('/')[0] === category
}
return false
})
}
}
function filterByFileFormats(formats: string[]) {
return (asset: AssetItem) => {
if (formats.length === 0) return true
const formatSet = new Set(formats)
const extension = asset.name.split('.').pop()?.toLowerCase()
return extension ? formatSet.has(extension) : false
}
}
function filterByBaseModels(models: string[]) {
return (asset: AssetItem) => {
if (models.length === 0) return true
const modelSet = new Set(models)
const assetBaseModels = getAssetBaseModels(asset)
return assetBaseModels.some((model) => modelSet.has(model))
}
}
function filterByOwnership(ownership: OwnershipOption) {
return (asset: AssetItem) => {
if (ownership === 'all') return true
if (ownership === 'my-models') return asset.is_immutable === false
if (ownership === 'public-models') return asset.is_immutable === true
return true
}
}
type AssetBadge = {
label: string
type: 'type' | 'base' | 'size'

View File

@@ -0,0 +1,181 @@
import { describe, expect, it } from 'vitest'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import {
filterByBaseModels,
filterByCategory,
filterByFileFormats,
filterByOwnership,
filterItemByBaseModels,
filterItemByOwnership
} from './assetFilterUtils'
function createAsset(
name: string,
options: Partial<AssetItem> = {}
): AssetItem {
return {
id: `asset-${name}`,
name,
tags: [],
is_immutable: false,
...options
} as AssetItem
}
describe('filterByCategory', () => {
it.for([
{ category: 'all', tags: ['checkpoint'], expected: true },
{ category: 'checkpoint', tags: ['checkpoint'], expected: true },
{ category: 'lora', tags: ['checkpoint'], expected: false },
{
category: 'checkpoint',
tags: ['models', 'checkpoint/xl'],
expected: true
},
{ category: 'xl', tags: ['models', 'checkpoint/xl'], expected: false }
])(
'category=$category with tags=$tags returns $expected',
({ category, tags, expected }) => {
const filter = filterByCategory(category)
const asset = createAsset('model.safetensors', { tags })
expect(filter(asset)).toBe(expected)
}
)
})
describe('filterByFileFormats', () => {
it.for([
{ formats: [], name: 'model.safetensors', expected: true },
{ formats: ['safetensors'], name: 'model.safetensors', expected: true },
{ formats: ['ckpt'], name: 'model.safetensors', expected: false },
{ formats: ['safetensors'], name: 'MODEL.SAFETENSORS', expected: true },
{ formats: ['safetensors'], name: 'model', expected: false }
])(
'formats=$formats with name=$name returns $expected',
({ formats, name, expected }) => {
const filter = filterByFileFormats(formats)
const asset = createAsset(name)
expect(filter(asset)).toBe(expected)
}
)
it('matches any of multiple formats', () => {
const filter = filterByFileFormats(['safetensors', 'ckpt', 'bin'])
expect(filter(createAsset('model.safetensors'))).toBe(true)
expect(filter(createAsset('model.ckpt'))).toBe(true)
expect(filter(createAsset('model.bin'))).toBe(true)
})
})
describe('filterByBaseModels', () => {
it.for([
{ models: [], expected: true },
{ models: new Set<string>(), expected: true }
])('empty models ($models) returns true', ({ models }) => {
const filter = filterByBaseModels(models)
const asset = createAsset('model.safetensors')
expect(filter(asset)).toBe(true)
})
it.for([
{
models: ['SDXL'],
metadata: { base_model: ['SDXL'] },
expected: true
},
{
models: ['SDXL'],
metadata: { base_model: ['SD1.5'] },
expected: false
},
{
models: new Set(['SDXL', 'SD1.5']),
metadata: { base_model: ['SDXL'] },
expected: true
}
])(
'models=$models with metadata.base_model returns $expected',
({ models, metadata, expected }) => {
const filter = filterByBaseModels(models)
const asset = createAsset('model.safetensors', { metadata })
expect(filter(asset)).toBe(expected)
}
)
it('matches base model in user_metadata', () => {
const filter = filterByBaseModels(['SD1.5'])
const asset = createAsset('model.safetensors', {
user_metadata: { base_model: ['SD1.5'] }
})
expect(filter(asset)).toBe(true)
})
})
describe('filterByOwnership', () => {
it.for([
{ ownership: 'all' as const, is_immutable: true, expected: true },
{ ownership: 'all' as const, is_immutable: false, expected: true },
{ ownership: 'my-models' as const, is_immutable: false, expected: true },
{ ownership: 'my-models' as const, is_immutable: true, expected: false },
{
ownership: 'public-models' as const,
is_immutable: true,
expected: true
},
{
ownership: 'public-models' as const,
is_immutable: false,
expected: false
}
])(
'ownership=$ownership with is_immutable=$is_immutable returns $expected',
({ ownership, is_immutable, expected }) => {
const filter = filterByOwnership(ownership)
const asset = createAsset('model', { is_immutable })
expect(filter(asset)).toBe(expected)
}
)
})
describe('filterItemByOwnership', () => {
const items = [
{ id: '1', is_immutable: true },
{ id: '2', is_immutable: false },
{ id: '3', is_immutable: true }
]
it.for([
{ ownership: 'all' as const, expectedIds: ['1', '2', '3'] },
{ ownership: 'my-models' as const, expectedIds: ['2'] },
{ ownership: 'public-models' as const, expectedIds: ['1', '3'] }
])(
'ownership=$ownership returns items with ids=$expectedIds',
({ ownership, expectedIds }) => {
const result = filterItemByOwnership(items, ownership)
expect(result.map((i) => i.id)).toEqual(expectedIds)
}
)
})
describe('filterItemByBaseModels', () => {
const items = [
{ id: '1', base_models: ['SDXL'] },
{ id: '2', base_models: ['SD1.5'] },
{ id: '3', base_models: ['SDXL', 'SD1.5'] },
{ id: '4' }
]
it.for([
{ selectedModels: new Set<string>(), expectedIds: ['1', '2', '3', '4'] },
{ selectedModels: new Set(['SDXL']), expectedIds: ['1', '3'] },
{ selectedModels: new Set(['SD1.5']), expectedIds: ['2', '3'] }
])(
'selectedModels=$selectedModels returns items with ids=$expectedIds',
({ selectedModels, expectedIds }) => {
const result = filterItemByBaseModels(items, selectedModels)
expect(result.map((i) => i.id)).toEqual(expectedIds)
}
)
})

View File

@@ -0,0 +1,66 @@
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import type { OwnershipOption } from '@/platform/assets/types/filterTypes'
import { getAssetBaseModels } from '@/platform/assets/utils/assetMetadataUtils'
export function filterByCategory(category: string) {
return (asset: AssetItem) => {
if (category === 'all') return true
// Check if any tag matches the category (for exact matches)
if (asset.tags.includes(category)) return true
// Check if any tag's top-level folder matches the category
return asset.tags.some((tag) => {
if (typeof tag === 'string' && tag.includes('/')) {
return tag.split('/')[0] === category
}
return false
})
}
}
export function filterByFileFormats(formats: string[]) {
return (asset: AssetItem) => {
if (formats.length === 0) return true
const formatSet = new Set(formats)
const extension = asset.name.split('.').pop()?.toLowerCase()
return extension ? formatSet.has(extension) : false
}
}
export function filterByBaseModels(models: string[] | Set<string>) {
const modelSet = models instanceof Set ? models : new Set(models)
return (asset: AssetItem) => {
if (modelSet.size === 0) return true
const assetBaseModels = getAssetBaseModels(asset)
return assetBaseModels.some((model) => modelSet.has(model))
}
}
export function filterByOwnership(ownership: OwnershipOption) {
return (asset: AssetItem) => {
if (ownership === 'all') return true
if (ownership === 'my-models') return asset.is_immutable === false
if (ownership === 'public-models') return asset.is_immutable === true
return true
}
}
export function filterItemByOwnership<T extends { is_immutable?: boolean }>(
items: T[],
ownership: OwnershipOption
): T[] {
if (ownership === 'all') return items
const isPublic = ownership === 'public-models'
return items.filter((item) => item.is_immutable === isPublic)
}
export function filterItemByBaseModels<T extends { base_models?: string[] }>(
items: T[],
selectedModels: Set<string>
): T[] {
if (selectedModels.size === 0) return items
return items.filter((item) =>
item.base_models?.some((model) => selectedModels.has(model))
)
}

View File

@@ -4,6 +4,10 @@ import { computed, provide, ref, toRef, watch } from 'vue'
import { useTransformCompatOverlayProps } from '@/composables/useTransformCompatOverlayProps'
import { t } from '@/i18n'
import {
filterItemByBaseModels,
filterItemByOwnership
} from '@/platform/assets/utils/assetFilterUtils'
import {
getAssetBaseModels,
getAssetDisplayName,
@@ -242,25 +246,16 @@ const assetItems = computed<FormDropdownItem[]>(() => {
}))
})
/**
* Filters asset items by ownership selection.
*/
const ownershipFilteredAssetItems = computed<FormDropdownItem[]>(() => {
if (ownershipSelected.value === 'all') return assetItems.value
const isPublic = ownershipSelected.value === 'public-models'
return assetItems.value.filter((item) => item.is_immutable === isPublic)
})
const ownershipFilteredAssetItems = computed<FormDropdownItem[]>(() =>
filterItemByOwnership(assetItems.value, ownershipSelected.value)
)
/**
* Filters asset items by base model selection.
*/
const baseModelFilteredAssetItems = computed<FormDropdownItem[]>(() => {
if (baseModelSelected.value.size === 0)
return ownershipFilteredAssetItems.value
return ownershipFilteredAssetItems.value.filter((item) =>
item.base_models?.some((model) => baseModelSelected.value.has(model))
const baseModelFilteredAssetItems = computed<FormDropdownItem[]>(() =>
filterItemByBaseModels(
ownershipFilteredAssetItems.value,
baseModelSelected.value
)
})
)
const allItems = computed<FormDropdownItem[]>(() => {
if (props.isAssetMode && assetData) {