diff --git a/src/services/nodeSearchService.ts b/src/services/nodeSearchService.ts index b33847872..078abf0b8 100644 --- a/src/services/nodeSearchService.ts +++ b/src/services/nodeSearchService.ts @@ -3,16 +3,23 @@ import { getNodeSource } from '@/types/nodeSource' import Fuse, { IFuseOptions, FuseSearchOptions } from 'fuse.js' import _ from 'lodash' +type SearchAuxScore = [number, number, number, number] + export class FuseSearch { private fuse: Fuse + private readonly keys: string[] public readonly data: T[] + public readonly advancedScoring: boolean constructor( data: T[], options?: IFuseOptions, - createIndex: boolean = true + createIndex: boolean = true, + advancedScoring: boolean = false ) { this.data = data + this.keys = (options.keys ?? []) as string[] + this.advancedScoring = advancedScoring const index = createIndex && options?.keys ? Fuse.createIndex(options.keys, data) @@ -24,7 +31,89 @@ export class FuseSearch { if (!query || query === '') { return [...this.data] } - return this.fuse.search(query, options).map((result) => result.item) + + const fuseResult = this.fuse.search(query, options) + if (!this.advancedScoring) { + return fuseResult.map((x) => x.item) + } + + const aux = fuseResult + .map((x) => ({ + item: x.item, + scores: this.calcAuxScores(query.toLocaleLowerCase(), x.item, x.score) + })) + .sort((a, b) => this.compareAux(a.scores, b.scores)) + + return aux.map((x) => x.item) + } + + public calcAuxScores(query: string, entry: T, score: number) { + let values: string[] = [] + if (!this.keys.length) values = [entry as string] + else values = this.keys.map((x) => entry[x]) + const scores = values.map((x) => this.calcAuxSingle(query, x, score)) + const result = scores.sort(this.compareAux)[0] + + const deprecated = values.some((x) => + x.toLocaleLowerCase().includes('deprecated') + ) + result[0] += deprecated && result[0] != 0 ? 5 : 0 + return result + } + + public calcAuxSingle( + query: string, + item: string, + score: number + ): SearchAuxScore { + const itemWords = item + .split(/ |\b|(?<=[a-z])(?=[A-Z])|(?=[A-Z][a-z])/) + .map((x) => x.toLocaleLowerCase()) + const queryParts = query.split(' ') + item = item.toLocaleLowerCase() + + let main = 9 + let aux1 = 0 + let aux2 = 0 + + if (item == query) { + main = 0 + } else if (item.startsWith(query)) { + main = 1 + aux2 = item.length + } else if (itemWords.includes(query)) { + main = 2 + aux1 = item.indexOf(query) + item.length * 0.5 + aux2 = item.length + } else if (item.includes(query)) { + main = 3 + aux1 = item.indexOf(query) + item.length * 0.5 + aux2 = item.length + } else if (queryParts.every((x) => itemWords.includes(x))) { + const indexes = queryParts.map((x) => itemWords.indexOf(x)) + const min = Math.min(...indexes) + const max = Math.max(...indexes) + main = 4 + aux1 = max - min + max * 0.5 + item.length * 0.5 + aux2 = item.length + } else if (queryParts.every((x) => item.includes(x))) { + const min = Math.min(...queryParts.map((x) => item.indexOf(x))) + const max = Math.max(...queryParts.map((x) => item.indexOf(x) + x.length)) + main = 5 + aux1 = max - min + max * 0.5 + item.length * 0.5 + aux2 = item.length + } + + const lengthPenalty = + 0.2 * + (1 - + Math.min(item.length, query.length) / + Math.max(item.length, query.length)) + return [main, aux1, aux2, score + lengthPenalty] + } + + public compareAux(a: SearchAuxScore, b: SearchAuxScore) { + return a[0] - b[0] || a[1] - b[1] || a[2] - b[2] || a[3] - b[3] } } @@ -116,20 +205,18 @@ export class NodeSearchService { public readonly nodeFilters: NodeFilter[] constructor(data: ComfyNodeDefImpl[]) { - this.nodeFuseSearch = new FuseSearch(data, { - keys: ['name', 'display_name'], - includeScore: true, - threshold: 0.3, - shouldSort: true, - useExtendedSearch: true, - // Sort by score, then by length of the display name, then by index - // Source: https://github.com/Comfy-Org/ComfyUI_frontend/issues/562#issuecomment-2303239027 - sortFn: (a, b) => - Math.min(a.score, b.score) < 0.0001 || - Math.abs(a.score - b.score) > 0.01 - ? a.score - b.score - : a.item[1]['v']['length'] - b.item[1]['v']['length'] || a.idx - b.idx - }) + this.nodeFuseSearch = new FuseSearch( + data, + { + keys: ['name', 'display_name'], + includeScore: true, + threshold: 0.3, + shouldSort: false, + useExtendedSearch: true + }, + true, + true + ) const filterSearchOptions = { includeScore: true,