diff --git a/src/composables/useTemplateFiltering.ts b/src/composables/useTemplateFiltering.ts index c22abb57bf..03a599f292 100644 --- a/src/composables/useTemplateFiltering.ts +++ b/src/composables/useTemplateFiltering.ts @@ -128,6 +128,17 @@ export function useTemplateFiltering( }) }) + const getVramMetric = (template: TemplateInfo) => { + if ( + typeof template.vram === 'number' && + Number.isFinite(template.vram) && + template.vram > 0 + ) { + return template.vram + } + return Number.POSITIVE_INFINITY + } + const sortedTemplates = computed(() => { const templates = [...filteredByLicenses.value] @@ -145,9 +156,21 @@ export function useTemplateFiltering( return dateB.getTime() - dateA.getTime() }) case 'vram-low-to-high': - // TODO: Implement VRAM sorting when VRAM data is available - // For now, keep original order - return templates + return templates.sort((a, b) => { + const vramA = getVramMetric(a) + const vramB = getVramMetric(b) + + if (vramA === vramB) { + const nameA = a.title || a.name || '' + const nameB = b.title || b.name || '' + return nameA.localeCompare(nameB) + } + + if (vramA === Number.POSITIVE_INFINITY) return 1 + if (vramB === Number.POSITIVE_INFINITY) return -1 + + return vramA - vramB + }) case 'model-size-low-to-high': return templates.sort((a: any, b: any) => { const sizeA = diff --git a/src/platform/workflow/templates/types/template.ts b/src/platform/workflow/templates/types/template.ts index dba93adf44..774f3a5857 100644 --- a/src/platform/workflow/templates/types/template.ts +++ b/src/platform/workflow/templates/types/template.ts @@ -18,6 +18,10 @@ export interface TemplateInfo { date?: string useCase?: string license?: string + /** + * Estimated VRAM requirement in bytes. + */ + vram?: number size?: number } diff --git a/tests-ui/tests/composables/useTemplateFiltering.test.ts b/tests-ui/tests/composables/useTemplateFiltering.test.ts new file mode 100644 index 0000000000..8fbdf59fd6 --- /dev/null +++ b/tests-ui/tests/composables/useTemplateFiltering.test.ts @@ -0,0 +1,231 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { nextTick, ref } from 'vue' + +import { useTemplateFiltering } from '@/composables/useTemplateFiltering' +import type { TemplateInfo } from '@/platform/workflow/templates/types/template' + +describe('useTemplateFiltering', () => { + afterEach(() => { + vi.useRealTimers() + }) + + it('sorts templates by VRAM from low to high and pushes missing values last', () => { + const gb = (value: number) => value * 1024 ** 3 + + const templates = ref([ + { + name: 'missing-vram', + description: 'no vram value', + mediaType: 'image', + mediaSubtype: 'png' + }, + { + name: 'highest-vram', + description: 'high usage', + mediaType: 'image', + mediaSubtype: 'png', + vram: gb(12) + }, + { + name: 'mid-vram', + description: 'medium usage', + mediaType: 'image', + mediaSubtype: 'png', + vram: gb(7.5) + }, + { + name: 'low-vram', + description: 'low usage', + mediaType: 'image', + mediaSubtype: 'png', + vram: gb(5) + }, + { + name: 'zero-vram', + description: 'unknown usage', + mediaType: 'image', + mediaSubtype: 'png', + vram: 0 + } + ]) + + const { sortBy, filteredTemplates } = useTemplateFiltering(templates) + + sortBy.value = 'vram-low-to-high' + + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'low-vram', + 'mid-vram', + 'highest-vram', + 'missing-vram', + 'zero-vram' + ]) + }) + + it('filters by search text, models, tags, and license with debounce handling', async () => { + vi.useFakeTimers() + + const templates = ref([ + { + name: 'api-template', + description: 'Enterprise API workflow for video', + mediaType: 'image', + mediaSubtype: 'png', + tags: ['API', 'Video'], + models: ['Flux'], + date: '2024-06-01', + vram: 15 * 1024 ** 3 + }, + { + name: 'portrait-flow', + description: 'Portrait template tuned for SDXL', + mediaType: 'image', + mediaSubtype: 'png', + tags: ['Portrait'], + models: ['SDXL'], + date: '2024-05-15', + vram: 10 * 1024 ** 3 + }, + { + name: 'landscape-lite', + description: 'Lightweight landscape generator', + mediaType: 'image', + mediaSubtype: 'png', + tags: ['Landscape'], + models: ['SDXL', 'Flux'], + date: '2024-04-20' + } + ]) + + const { + searchQuery, + selectedModels, + selectedUseCases, + selectedLicenses, + filteredTemplates, + availableModels, + availableUseCases, + availableLicenses, + filteredCount, + totalCount, + removeUseCaseFilter, + resetFilters + } = useTemplateFiltering(templates) + + expect(totalCount.value).toBe(3) + expect(availableModels.value).toEqual(['Flux', 'SDXL']) + expect(availableUseCases.value).toEqual([ + 'API', + 'Landscape', + 'Portrait', + 'Video' + ]) + expect(availableLicenses.value).toEqual([ + 'Open Source', + 'Closed Source (API Nodes)' + ]) + + searchQuery.value = 'enterprise' + await nextTick() + await vi.runOnlyPendingTimersAsync() + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'api-template' + ]) + + selectedLicenses.value = ['Closed Source (API Nodes)'] + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'api-template' + ]) + + selectedModels.value = ['Flux'] + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'api-template' + ]) + + selectedUseCases.value = ['Video'] + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'api-template' + ]) + expect(filteredCount.value).toBe(1) + + removeUseCaseFilter('Video') + await nextTick() + expect(selectedUseCases.value).toHaveLength(0) + + resetFilters() + await nextTick() + await vi.runOnlyPendingTimersAsync() + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'api-template', + 'portrait-flow', + 'landscape-lite' + ]) + }) + + it('supports alphabetical, newest, and size-based sorting options', async () => { + const templates = ref([ + { + name: 'zeta-extended', + description: 'older template', + mediaType: 'image', + mediaSubtype: 'png', + date: '2024-01-01', + size: 300 + }, + { + name: 'alpha-starter', + description: 'new template', + mediaType: 'image', + mediaSubtype: 'png', + date: '2024-07-01', + size: 100 + }, + { + name: 'beta-pro', + description: 'mid template', + mediaType: 'image', + mediaSubtype: 'png', + date: '2024-05-01', + size: 200 + } + ]) + + const { sortBy, filteredTemplates } = useTemplateFiltering(templates) + + // default is 'newest' + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'alpha-starter', + 'beta-pro', + 'zeta-extended' + ]) + + sortBy.value = 'alphabetical' + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'alpha-starter', + 'beta-pro', + 'zeta-extended' + ]) + + sortBy.value = 'model-size-low-to-high' + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'alpha-starter', + 'beta-pro', + 'zeta-extended' + ]) + + sortBy.value = 'default' + await nextTick() + expect(filteredTemplates.value.map((template) => template.name)).toEqual([ + 'zeta-extended', + 'alpha-starter', + 'beta-pro' + ]) + }) +})