From 091e67590b480db0160164030008e660ebe8e93d Mon Sep 17 00:00:00 2001 From: Alexander Brown <448862+DrJKL@users.noreply.github.com> Date: Sun, 1 Feb 2026 12:27:44 -0800 Subject: [PATCH] refactor: extract shared asset filter utilities - Create assetFilterUtils.ts with reusable filter functions - Update useAssetBrowser.ts and WidgetSelectDropdown.vue to use shared utilities - Add comprehensive parameterized tests using it.for Amp-Thread-ID: https://ampcode.com/threads/T-019c18df-419b-7309-ae68-7f05682938d3 Co-authored-by: Amp --- .../assets/composables/useAssetBrowser.ts | 50 +---- .../assets/utils/assetFilterUtils.test.ts | 181 ++++++++++++++++++ src/platform/assets/utils/assetFilterUtils.ts | 66 +++++++ .../components/WidgetSelectDropdown.vue | 29 ++- 4 files changed, 265 insertions(+), 61 deletions(-) create mode 100644 src/platform/assets/utils/assetFilterUtils.test.ts create mode 100644 src/platform/assets/utils/assetFilterUtils.ts diff --git a/src/platform/assets/composables/useAssetBrowser.ts b/src/platform/assets/composables/useAssetBrowser.ts index 9c98be3ca..773a53ffe 100644 --- a/src/platform/assets/composables/useAssetBrowser.ts +++ b/src/platform/assets/composables/useAssetBrowser.ts @@ -10,6 +10,12 @@ import type { OwnershipOption } from '@/platform/assets/types/filterTypes' import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { + filterByBaseModels, + filterByCategory, + filterByFileFormats, + filterByOwnership +} from '@/platform/assets/utils/assetFilterUtils' import { getAssetBaseModels, getAssetFilename @@ -20,50 +26,6 @@ import type { NavGroupData, NavItemData } from '@/types/navTypes' type NavId = 'all' | 'imported' | (string & {}) -function filterByCategory(category: string) { - return (asset: AssetItem) => { - if (category === 'all') return true - - // Check if any tag matches the category (for exact matches) - if (asset.tags.includes(category)) return true - - // Check if any tag's top-level folder matches the category - return asset.tags.some((tag) => { - if (typeof tag === 'string' && tag.includes('/')) { - return tag.split('/')[0] === category - } - return false - }) - } -} - -function filterByFileFormats(formats: string[]) { - return (asset: AssetItem) => { - if (formats.length === 0) return true - const formatSet = new Set(formats) - const extension = asset.name.split('.').pop()?.toLowerCase() - return extension ? formatSet.has(extension) : false - } -} - -function filterByBaseModels(models: string[]) { - return (asset: AssetItem) => { - if (models.length === 0) return true - const modelSet = new Set(models) - const assetBaseModels = getAssetBaseModels(asset) - return assetBaseModels.some((model) => modelSet.has(model)) - } -} - -function filterByOwnership(ownership: OwnershipOption) { - return (asset: AssetItem) => { - if (ownership === 'all') return true - if (ownership === 'my-models') return asset.is_immutable === false - if (ownership === 'public-models') return asset.is_immutable === true - return true - } -} - type AssetBadge = { label: string type: 'type' | 'base' | 'size' diff --git a/src/platform/assets/utils/assetFilterUtils.test.ts b/src/platform/assets/utils/assetFilterUtils.test.ts new file mode 100644 index 000000000..8cf23098c --- /dev/null +++ b/src/platform/assets/utils/assetFilterUtils.test.ts @@ -0,0 +1,181 @@ +import { describe, expect, it } from 'vitest' + +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' + +import { + filterByBaseModels, + filterByCategory, + filterByFileFormats, + filterByOwnership, + filterItemByBaseModels, + filterItemByOwnership +} from './assetFilterUtils' + +function createAsset( + name: string, + options: Partial = {} +): AssetItem { + return { + id: `asset-${name}`, + name, + tags: [], + is_immutable: false, + ...options + } as AssetItem +} + +describe('filterByCategory', () => { + it.for([ + { category: 'all', tags: ['checkpoint'], expected: true }, + { category: 'checkpoint', tags: ['checkpoint'], expected: true }, + { category: 'lora', tags: ['checkpoint'], expected: false }, + { + category: 'checkpoint', + tags: ['models', 'checkpoint/xl'], + expected: true + }, + { category: 'xl', tags: ['models', 'checkpoint/xl'], expected: false } + ])( + 'category=$category with tags=$tags returns $expected', + ({ category, tags, expected }) => { + const filter = filterByCategory(category) + const asset = createAsset('model.safetensors', { tags }) + expect(filter(asset)).toBe(expected) + } + ) +}) + +describe('filterByFileFormats', () => { + it.for([ + { formats: [], name: 'model.safetensors', expected: true }, + { formats: ['safetensors'], name: 'model.safetensors', expected: true }, + { formats: ['ckpt'], name: 'model.safetensors', expected: false }, + { formats: ['safetensors'], name: 'MODEL.SAFETENSORS', expected: true }, + { formats: ['safetensors'], name: 'model', expected: false } + ])( + 'formats=$formats with name=$name returns $expected', + ({ formats, name, expected }) => { + const filter = filterByFileFormats(formats) + const asset = createAsset(name) + expect(filter(asset)).toBe(expected) + } + ) + + it('matches any of multiple formats', () => { + const filter = filterByFileFormats(['safetensors', 'ckpt', 'bin']) + expect(filter(createAsset('model.safetensors'))).toBe(true) + expect(filter(createAsset('model.ckpt'))).toBe(true) + expect(filter(createAsset('model.bin'))).toBe(true) + }) +}) + +describe('filterByBaseModels', () => { + it.for([ + { models: [], expected: true }, + { models: new Set(), expected: true } + ])('empty models ($models) returns true', ({ models }) => { + const filter = filterByBaseModels(models) + const asset = createAsset('model.safetensors') + expect(filter(asset)).toBe(true) + }) + + it.for([ + { + models: ['SDXL'], + metadata: { base_model: ['SDXL'] }, + expected: true + }, + { + models: ['SDXL'], + metadata: { base_model: ['SD1.5'] }, + expected: false + }, + { + models: new Set(['SDXL', 'SD1.5']), + metadata: { base_model: ['SDXL'] }, + expected: true + } + ])( + 'models=$models with metadata.base_model returns $expected', + ({ models, metadata, expected }) => { + const filter = filterByBaseModels(models) + const asset = createAsset('model.safetensors', { metadata }) + expect(filter(asset)).toBe(expected) + } + ) + + it('matches base model in user_metadata', () => { + const filter = filterByBaseModels(['SD1.5']) + const asset = createAsset('model.safetensors', { + user_metadata: { base_model: ['SD1.5'] } + }) + expect(filter(asset)).toBe(true) + }) +}) + +describe('filterByOwnership', () => { + it.for([ + { ownership: 'all' as const, is_immutable: true, expected: true }, + { ownership: 'all' as const, is_immutable: false, expected: true }, + { ownership: 'my-models' as const, is_immutable: false, expected: true }, + { ownership: 'my-models' as const, is_immutable: true, expected: false }, + { + ownership: 'public-models' as const, + is_immutable: true, + expected: true + }, + { + ownership: 'public-models' as const, + is_immutable: false, + expected: false + } + ])( + 'ownership=$ownership with is_immutable=$is_immutable returns $expected', + ({ ownership, is_immutable, expected }) => { + const filter = filterByOwnership(ownership) + const asset = createAsset('model', { is_immutable }) + expect(filter(asset)).toBe(expected) + } + ) +}) + +describe('filterItemByOwnership', () => { + const items = [ + { id: '1', is_immutable: true }, + { id: '2', is_immutable: false }, + { id: '3', is_immutable: true } + ] + + it.for([ + { ownership: 'all' as const, expectedIds: ['1', '2', '3'] }, + { ownership: 'my-models' as const, expectedIds: ['2'] }, + { ownership: 'public-models' as const, expectedIds: ['1', '3'] } + ])( + 'ownership=$ownership returns items with ids=$expectedIds', + ({ ownership, expectedIds }) => { + const result = filterItemByOwnership(items, ownership) + expect(result.map((i) => i.id)).toEqual(expectedIds) + } + ) +}) + +describe('filterItemByBaseModels', () => { + const items = [ + { id: '1', base_models: ['SDXL'] }, + { id: '2', base_models: ['SD1.5'] }, + { id: '3', base_models: ['SDXL', 'SD1.5'] }, + { id: '4' } + ] + + it.for([ + { selectedModels: new Set(), expectedIds: ['1', '2', '3', '4'] }, + { selectedModels: new Set(['SDXL']), expectedIds: ['1', '3'] }, + { selectedModels: new Set(['SD1.5']), expectedIds: ['2', '3'] } + ])( + 'selectedModels=$selectedModels returns items with ids=$expectedIds', + ({ selectedModels, expectedIds }) => { + const result = filterItemByBaseModels(items, selectedModels) + expect(result.map((i) => i.id)).toEqual(expectedIds) + } + ) +}) diff --git a/src/platform/assets/utils/assetFilterUtils.ts b/src/platform/assets/utils/assetFilterUtils.ts new file mode 100644 index 000000000..cdcc1946b --- /dev/null +++ b/src/platform/assets/utils/assetFilterUtils.ts @@ -0,0 +1,66 @@ +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import type { OwnershipOption } from '@/platform/assets/types/filterTypes' +import { getAssetBaseModels } from '@/platform/assets/utils/assetMetadataUtils' + +export function filterByCategory(category: string) { + return (asset: AssetItem) => { + if (category === 'all') return true + + // Check if any tag matches the category (for exact matches) + if (asset.tags.includes(category)) return true + + // Check if any tag's top-level folder matches the category + return asset.tags.some((tag) => { + if (typeof tag === 'string' && tag.includes('/')) { + return tag.split('/')[0] === category + } + return false + }) + } +} + +export function filterByFileFormats(formats: string[]) { + return (asset: AssetItem) => { + if (formats.length === 0) return true + const formatSet = new Set(formats) + const extension = asset.name.split('.').pop()?.toLowerCase() + return extension ? formatSet.has(extension) : false + } +} + +export function filterByBaseModels(models: string[] | Set) { + const modelSet = models instanceof Set ? models : new Set(models) + return (asset: AssetItem) => { + if (modelSet.size === 0) return true + const assetBaseModels = getAssetBaseModels(asset) + return assetBaseModels.some((model) => modelSet.has(model)) + } +} + +export function filterByOwnership(ownership: OwnershipOption) { + return (asset: AssetItem) => { + if (ownership === 'all') return true + if (ownership === 'my-models') return asset.is_immutable === false + if (ownership === 'public-models') return asset.is_immutable === true + return true + } +} + +export function filterItemByOwnership( + items: T[], + ownership: OwnershipOption +): T[] { + if (ownership === 'all') return items + const isPublic = ownership === 'public-models' + return items.filter((item) => item.is_immutable === isPublic) +} + +export function filterItemByBaseModels( + items: T[], + selectedModels: Set +): T[] { + if (selectedModels.size === 0) return items + return items.filter((item) => + item.base_models?.some((model) => selectedModels.has(model)) + ) +} diff --git a/src/renderer/extensions/vueNodes/widgets/components/WidgetSelectDropdown.vue b/src/renderer/extensions/vueNodes/widgets/components/WidgetSelectDropdown.vue index 020fa508c..5062db958 100644 --- a/src/renderer/extensions/vueNodes/widgets/components/WidgetSelectDropdown.vue +++ b/src/renderer/extensions/vueNodes/widgets/components/WidgetSelectDropdown.vue @@ -4,6 +4,10 @@ import { computed, provide, ref, toRef, watch } from 'vue' import { useTransformCompatOverlayProps } from '@/composables/useTransformCompatOverlayProps' import { t } from '@/i18n' +import { + filterItemByBaseModels, + filterItemByOwnership +} from '@/platform/assets/utils/assetFilterUtils' import { getAssetBaseModels, getAssetDisplayName, @@ -242,25 +246,16 @@ const assetItems = computed(() => { })) }) -/** - * Filters asset items by ownership selection. - */ -const ownershipFilteredAssetItems = computed(() => { - if (ownershipSelected.value === 'all') return assetItems.value - const isPublic = ownershipSelected.value === 'public-models' - return assetItems.value.filter((item) => item.is_immutable === isPublic) -}) +const ownershipFilteredAssetItems = computed(() => + filterItemByOwnership(assetItems.value, ownershipSelected.value) +) -/** - * Filters asset items by base model selection. - */ -const baseModelFilteredAssetItems = computed(() => { - if (baseModelSelected.value.size === 0) - return ownershipFilteredAssetItems.value - return ownershipFilteredAssetItems.value.filter((item) => - item.base_models?.some((model) => baseModelSelected.value.has(model)) +const baseModelFilteredAssetItems = computed(() => + filterItemByBaseModels( + ownershipFilteredAssetItems.value, + baseModelSelected.value ) -}) +) const allItems = computed(() => { if (props.isAssetMode && assetData) {