Compare commits

...

1 Commits

Author SHA1 Message Date
Comfy Org PR Bot
6d1221bc2f [backport cloud/1.44] refactor: align asset pagination schema (#12065)
Backport of #11899 to `cloud/1.44`

Automatically created by backport workflow.

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-12065-backport-cloud-1-44-refactor-align-asset-pagination-schema-3596d73d365081a5b596c288ef8a818a)
by [Unito](https://www.unito.io)

Co-authored-by: jaeone94 <89377375+jaeone94@users.noreply.github.com>
2026-05-07 17:36:36 +00:00
5 changed files with 193 additions and 140 deletions

View File

@@ -215,11 +215,12 @@ export class AssetHelper {
return this.store.size
}
private handleListAssets(route: Route, url: URL) {
const includeTags = url.searchParams.get('include_tags')?.split(',') ?? []
const includeTags = parseAssetTagParam(url.searchParams.get('include_tags'))
const excludeTags = parseAssetTagParam(url.searchParams.get('exclude_tags'))
const limit = parseInt(url.searchParams.get('limit') ?? '0', 10)
const offset = parseInt(url.searchParams.get('offset') ?? '0', 10)
let filtered = this.getFilteredAssets(includeTags)
let filtered = this.getFilteredAssets(includeTags, excludeTags)
if (limit > 0) {
filtered = filtered.slice(offset, offset + limit)
}
@@ -296,15 +297,29 @@ export class AssetHelper {
this.paginationOptions = null
this.uploadResponse = null
}
private getFilteredAssets(tags: string[]): Asset[] {
private getFilteredAssets(
includeTags: string[],
excludeTags: string[]
): Asset[] {
const assets = [...this.store.values()]
if (tags.length === 0) return assets
return assets.filter((asset) =>
tags.every((tag) => (asset.tags ?? []).includes(tag))
return assets.filter(
(asset) =>
includeTags.every((tag) => (asset.tags ?? []).includes(tag)) &&
excludeTags.every((tag) => !(asset.tags ?? []).includes(tag))
)
}
}
function parseAssetTagParam(value: string | null): string[] {
return (
value
?.split(',')
.map((tag) => tag.trim())
.filter(Boolean) ?? []
)
}
export function createAssetHelper(
page: Page,
...operators: AssetOperator[]

View File

@@ -133,6 +133,29 @@ test.describe('AssetHelper', () => {
expect(data.assets[0].id).toBe(STABLE_CHECKPOINT.id)
})
test('GET /assets filters by exclude_tags', async ({
comfyPage,
assetApi
}) => {
assetApi.configure(
withAsset(STABLE_INPUT_IMAGE),
withAsset({
...STABLE_INPUT_IMAGE,
id: 'missing-input',
tags: ['input', 'missing']
})
)
await assetApi.mock()
const { body } = await assetApi.fetch(
`${comfyPage.url}/api/assets?include_tags=input,&exclude_tags= missing,`
)
const data = body as { assets: Array<{ id: string }> }
expect(data.assets.map((asset) => asset.id)).toEqual([
STABLE_INPUT_IMAGE.id
])
})
test('GET /assets/:id returns single asset or 404', async ({
comfyPage,
assetApi

View File

@@ -1,3 +1,4 @@
import { zListAssetsResponse } from '@comfyorg/ingest-types/zod'
import { z } from 'zod'
// Zod schemas for asset API validation matching ComfyUI Assets REST API spec
@@ -20,11 +21,11 @@ const zAsset = z.object({
user_metadata: z.record(z.unknown()).optional() // API allows arbitrary key-value pairs
})
const zAssetResponse = z.object({
assets: z.array(zAsset).optional(),
total: z.number().optional(),
has_more: z.boolean().optional()
})
const zAssetResponse = zListAssetsResponse
.pick({ total: true, has_more: true })
.extend({
assets: z.array(zAsset)
})
const zModelFolder = z.object({
name: z.string(),

View File

@@ -1,6 +1,9 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import type {
AssetItem,
AssetResponse
} from '@/platform/assets/schemas/assetSchema'
import {
MISSING_TAG,
assetService,
@@ -53,6 +56,11 @@ const validBlake3Hash =
'1111111111111111111111111111111111111111111111111111111111111111'
const validBlake3AssetHash = `blake3:${validBlake3Hash}`
type AssetListResponseOptions = {
hasMore?: AssetResponse['has_more']
total?: AssetResponse['total']
}
function buildResponse(
body: unknown,
init: { ok?: boolean; status?: number } = {}
@@ -64,6 +72,13 @@ function buildResponse(
} as unknown as Response
}
function buildAssetListResponse(
assets: AssetItem[],
{ hasMore = false, total = assets.length }: AssetListResponseOptions = {}
): Response {
return buildResponse({ assets, total, has_more: hasMore })
}
function validAsset(overrides: Partial<AssetItem> = {}): AssetItem {
return {
id: 'asset-1',
@@ -218,7 +233,7 @@ describe(assetService.uploadAssetFromUrl, () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(buildResponse({ id: 'missing-name' }))
await assetService.getInputAssetsIncludingPublic()
@@ -240,7 +255,7 @@ describe(assetService.uploadAssetFromUrl, () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(
buildResponse(validAsset({ id: 'uploaded-input', tags: ['input'] }))
)
@@ -301,7 +316,7 @@ describe(assetService.uploadAssetFromBase64, () => {
.spyOn(globalThis, 'fetch')
.mockResolvedValueOnce(new Response('hello'))
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(buildResponse({ id: 'missing-name' }))
await assetService.getInputAssetsIncludingPublic()
@@ -327,7 +342,7 @@ describe(assetService.uploadAssetFromBase64, () => {
.spyOn(globalThis, 'fetch')
.mockResolvedValueOnce(new Response('hello'))
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(
buildResponse({
...validAsset({ id: 'uploaded-input', tags: ['input'] }),
@@ -421,17 +436,14 @@ describe(assetService.getAssetModelFolders, () => {
vi.clearAllMocks()
})
it('filters out missing-tagged assets and blacklisted directories, returning alphabetical unique folders without include_public', async () => {
it('requests missing-tag exclusion and returns alphabetical unique folders without include_public', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({
assets: [
validAsset({ id: 'a', tags: ['models', 'loras'] }),
validAsset({ id: 'b', tags: ['models', 'checkpoints'] }),
validAsset({ id: 'c', tags: ['models', 'configs'] }),
validAsset({ id: 'd', tags: ['models', 'missing', 'controlnet'] }),
validAsset({ id: 'e', tags: ['models', 'loras'] })
]
})
buildAssetListResponse([
validAsset({ id: 'a', tags: ['models', 'loras'] }),
validAsset({ id: 'b', tags: ['models', 'checkpoints'] }),
validAsset({ id: 'c', tags: ['models', 'configs'] }),
validAsset({ id: 'e', tags: ['models', 'loras'] })
])
)
const folders = await assetService.getAssetModelFolders()
@@ -444,6 +456,7 @@ describe(assetService.getAssetModelFolders, () => {
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
const params = new URL(requestedUrl, 'http://localhost').searchParams
expect(params.has('include_public')).toBe(false)
expect(params.get('exclude_tags')).toBe(MISSING_TAG)
})
})
@@ -490,14 +503,9 @@ describe(assetService.getAssetsByTag, () => {
vi.clearAllMocks()
})
it('forwards include_public=true by default and excludes missing-tagged assets', async () => {
it('forwards include_public=true by default and requests missing-tag exclusion', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({
assets: [
validAsset({ id: 'visible', tags: ['input'] }),
validAsset({ id: 'hidden', tags: ['input', 'missing'] })
]
})
buildAssetListResponse([validAsset({ id: 'visible', tags: ['input'] })])
)
const assets = await assetService.getAssetsByTag('input')
@@ -507,6 +515,20 @@ describe(assetService.getAssetsByTag, () => {
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
const params = new URL(requestedUrl, 'http://localhost').searchParams
expect(params.get('include_public')).toBe('true')
expect(params.get('exclude_tags')).toBe(MISSING_TAG)
})
it('normalizes tag query parameters', async () => {
fetchApiMock.mockResolvedValueOnce(
buildAssetListResponse([validAsset({ id: 'visible', tags: ['input'] })])
)
await assetService.getAssetsByTag(' input ')
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
const params = new URL(requestedUrl, 'http://localhost').searchParams
expect(params.get('include_tags')).toBe('input')
expect(params.get('exclude_tags')).toBe(MISSING_TAG)
})
})
@@ -518,17 +540,16 @@ describe(assetService.getAllAssetsByTag, () => {
it('paginates tagged asset requests with include_public=true', async () => {
fetchApiMock
.mockResolvedValueOnce(
buildResponse({
assets: [
buildAssetListResponse(
[
validAsset({ id: 'a', tags: ['input'] }),
validAsset({ id: 'b', tags: ['input'] })
]
})
],
{ hasMore: true }
)
)
.mockResolvedValueOnce(
buildResponse({
assets: [validAsset({ id: 'c', tags: ['input'] })]
})
buildAssetListResponse([validAsset({ id: 'c', tags: ['input'] })])
)
const assets = await assetService.getAllAssetsByTag('input', true, {
@@ -540,63 +561,33 @@ describe(assetService.getAllAssetsByTag, () => {
const firstUrl = fetchApiMock.mock.calls[0]?.[0] as string
const firstParams = new URL(firstUrl, 'http://localhost').searchParams
expect(firstParams.get('include_public')).toBe('true')
expect(firstParams.get('exclude_tags')).toBe(MISSING_TAG)
expect(firstParams.get('limit')).toBe('2')
expect(firstParams.has('offset')).toBe(false)
const secondUrl = fetchApiMock.mock.calls[1]?.[0] as string
const secondParams = new URL(secondUrl, 'http://localhost').searchParams
expect(secondParams.get('include_public')).toBe('true')
expect(secondParams.get('exclude_tags')).toBe(MISSING_TAG)
expect(secondParams.get('limit')).toBe('2')
expect(secondParams.get('offset')).toBe('2')
})
it('paginates from raw response size before filtering missing-tagged assets', async () => {
fetchApiMock
.mockResolvedValueOnce(
buildResponse({
assets: [
validAsset({ id: 'visible', tags: ['input'] }),
validAsset({ id: 'hidden', tags: ['input', MISSING_TAG] })
]
})
)
.mockResolvedValueOnce(
buildResponse({
assets: [validAsset({ id: 'later-public', tags: ['input'] })]
})
)
const assets = await assetService.getAllAssetsByTag('input', true, {
limit: 2
})
expect(assets.map((a) => a.id)).toEqual(['visible', 'later-public'])
expect(fetchApiMock).toHaveBeenCalledTimes(2)
const secondUrl = fetchApiMock.mock.calls[1]?.[0]
if (typeof secondUrl !== 'string') {
throw new Error('Expected a second asset request URL')
}
const secondParams = new URL(secondUrl, 'http://localhost').searchParams
expect(secondParams.get('offset')).toBe('2')
})
it('honors has_more when walking tagged asset pages', async () => {
fetchApiMock
.mockResolvedValueOnce(
buildResponse({
assets: [
buildAssetListResponse(
[
validAsset({ id: 'first', tags: ['input'] }),
validAsset({ id: 'second', tags: ['input'] })
],
has_more: true
})
{ hasMore: true }
)
)
.mockResolvedValueOnce(
buildResponse({
assets: [validAsset({ id: 'later-public', tags: ['input'] })],
has_more: false
})
buildAssetListResponse([
validAsset({ id: 'later-public', tags: ['input'] })
])
)
const assets = await assetService.getAllAssetsByTag('input', true, {
@@ -614,12 +605,41 @@ describe(assetService.getAllAssetsByTag, () => {
expect(secondParams.get('offset')).toBe('2')
})
it.each([
{
name: 'missing has_more',
body: {
assets: [validAsset({ id: 'a', tags: ['input'] })],
total: 1
}
},
{
name: 'missing total',
body: {
assets: [validAsset({ id: 'a', tags: ['input'] })],
has_more: false
}
},
{
name: 'non-boolean has_more',
body: {
assets: [validAsset({ id: 'a', tags: ['input'] })],
total: 1,
has_more: 'false'
}
}
])('rejects asset responses with $name', async ({ body }) => {
fetchApiMock.mockResolvedValueOnce(buildResponse(body))
await expect(
assetService.getAllAssetsByTag('input', true, { limit: 2 })
).rejects.toThrow(/Invalid asset response/)
})
it('passes abort signals through paginated requests', async () => {
const controller = new AbortController()
fetchApiMock.mockResolvedValueOnce(
buildResponse({
assets: [validAsset({ id: 'a', tags: ['input'] })]
})
buildAssetListResponse([validAsset({ id: 'a', tags: ['input'] })])
)
await assetService.getAllAssetsByTag('input', true, {
@@ -636,12 +656,13 @@ describe(assetService.getAllAssetsByTag, () => {
const controller = new AbortController()
fetchApiMock.mockImplementationOnce(async () => {
controller.abort()
return buildResponse({
assets: [
return buildAssetListResponse(
[
validAsset({ id: 'a', tags: ['input'] }),
validAsset({ id: 'b', tags: ['input'] })
]
})
],
{ hasMore: true }
)
})
await expect(
@@ -666,7 +687,7 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
validAsset({ id: 'user-input', tags: ['input'] }),
validAsset({ id: 'public-input', tags: ['input'], is_immutable: true })
]
fetchApiMock.mockResolvedValueOnce(buildResponse({ assets }))
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse(assets))
const first = await assetService.getInputAssetsIncludingPublic()
const second = await assetService.getInputAssetsIncludingPublic()
@@ -685,8 +706,8 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })]
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildResponse({ assets: freshAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(buildAssetListResponse(freshAssets))
await assetService.getInputAssetsIncludingPublic()
assetService.invalidateInputAssetsIncludingPublic()
@@ -720,7 +741,7 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
await expect(first).rejects.toMatchObject({ name: 'AbortError' })
expect(serviceSignal).toBeUndefined()
resolveResponse(buildResponse({ assets }))
resolveResponse(buildAssetListResponse(assets))
await expect(second).resolves.toEqual(assets)
expect(fetchApiMock).toHaveBeenCalledOnce()
@@ -750,7 +771,7 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
await expect(first).rejects.toMatchObject({ name: 'AbortError' })
await expect(second).rejects.toMatchObject({ name: 'AbortError' })
resolveResponse(buildResponse({ assets }))
resolveResponse(buildAssetListResponse(assets))
await Promise.resolve()
await expect(assetService.getInputAssetsIncludingPublic()).resolves.toEqual(
@@ -770,12 +791,12 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
resolveResponse = resolve
})
)
.mockResolvedValueOnce(buildResponse({ assets: freshAssets }))
.mockResolvedValueOnce(buildAssetListResponse(freshAssets))
const inFlight = assetService.getInputAssetsIncludingPublic()
assetService.invalidateInputAssetsIncludingPublic()
resolveResponse(buildResponse({ assets }))
resolveResponse(buildAssetListResponse(assets))
await expect(inFlight).resolves.toEqual(assets)
await expect(assetService.getInputAssetsIncludingPublic()).resolves.toEqual(
@@ -788,9 +809,9 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })]
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(buildResponse(null))
.mockResolvedValueOnce(buildResponse({ assets: freshAssets }))
.mockResolvedValueOnce(buildAssetListResponse(freshAssets))
await assetService.getInputAssetsIncludingPublic()
await assetService.deleteAsset('stale-input')
@@ -809,9 +830,9 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
const uploadedAsset = validAsset({ id: 'uploaded-input', tags: ['input'] })
const freshAssets = [uploadedAsset]
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(buildResponse(uploadedAsset))
.mockResolvedValueOnce(buildResponse({ assets: freshAssets }))
.mockResolvedValueOnce(buildAssetListResponse(freshAssets))
await assetService.getInputAssetsIncludingPublic()
await assetService.uploadAssetAsync({
@@ -827,7 +848,7 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
it('does not invalidate cached input assets for pending async input uploads', async () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(
buildResponse(
{ task_id: 'task-1', status: 'running' },
@@ -849,7 +870,7 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
it('does not invalidate cached input assets for non-input uploads', async () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
fetchApiMock
.mockResolvedValueOnce(buildResponse({ assets: staleAssets }))
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(buildResponse(validAsset({ tags: ['models'] })))
await assetService.getInputAssetsIncludingPublic()

View File

@@ -36,6 +36,7 @@ interface AssetPaginationOptions extends PaginationOptions {
interface AssetRequestOptions extends PaginationOptions {
includeTags: string[]
excludeTags?: string[]
includePublic?: boolean
signal?: AbortSignal
}
@@ -181,6 +182,7 @@ const INPUT_ASSETS_WITH_PUBLIC_LIMIT = 500
export const MODELS_TAG = 'models'
/** Asset tag used by the backend for placeholder records that are not installed. */
export const MISSING_TAG = 'missing'
const DEFAULT_EXCLUDED_ASSET_TAGS = [MISSING_TAG]
/** Result of a HEAD lookup against an exact asset hash. */
export type AssetHashStatus = 'exists' | 'missing' | 'invalid'
@@ -210,6 +212,10 @@ function throwIfAborted(signal?: AbortSignal): void {
if (signal?.aborted) throw createAbortError()
}
function normalizeAssetTags(tags: string[]): string[] {
return tags.map((tag) => tag.trim()).filter(Boolean)
}
async function withCallerAbort<T>(
promise: Promise<T>,
signal?: AbortSignal
@@ -290,15 +296,22 @@ function createAssetService() {
): Promise<AssetResponse> {
const {
includeTags,
excludeTags = DEFAULT_EXCLUDED_ASSET_TAGS,
limit = DEFAULT_LIMIT,
offset,
includePublic,
signal
} = options
const normalizedIncludeTags = normalizeAssetTags(includeTags)
const normalizedExcludeTags = normalizeAssetTags(excludeTags)
const queryParams = new URLSearchParams({
include_tags: includeTags.join(','),
include_tags: normalizedIncludeTags.join(','),
limit: limit.toString()
})
if (normalizedExcludeTags.length > 0) {
queryParams.set('exclude_tags', normalizedExcludeTags.join(','))
}
if (offset !== undefined && offset > 0) {
queryParams.set('offset', offset.toString())
}
@@ -337,15 +350,10 @@ function createAssetService() {
// Blacklist directories we don't want to show
const blacklistedDirectories = new Set(['configs'])
// Extract directory names from assets that actually exist, exclude missing assets
const discoveredFolders = new Set<string>(
data?.assets
?.filter((asset) => !asset.tags.includes(MISSING_TAG))
?.flatMap((asset) => asset.tags)
?.filter(
(tag) => tag !== MODELS_TAG && !blacklistedDirectories.has(tag)
) ?? []
)
const folderTags = data.assets
.flatMap((asset) => asset.tags)
.filter((tag) => tag !== MODELS_TAG && !blacklistedDirectories.has(tag))
const discoveredFolders = new Set<string>(folderTags)
// Return only discovered folders in alphabetical order
const sortedFolders = Array.from(discoveredFolders).toSorted()
@@ -363,17 +371,10 @@ function createAssetService() {
`models for ${folder}`
)
return (
data?.assets
?.filter(
(asset) =>
!asset.tags.includes(MISSING_TAG) && asset.tags.includes(folder)
)
?.map((asset) => ({
name: asset.name,
pathIndex: 0
})) ?? []
)
return data.assets.map((asset) => ({
name: asset.name,
pathIndex: 0
}))
}
/**
@@ -449,12 +450,7 @@ function createAssetService() {
)
// Return full AssetItem[] objects (don't strip like getAssetModels does)
return (
data?.assets?.filter(
(asset) =>
!asset.tags.includes(MISSING_TAG) && asset.tags.includes(category)
) ?? []
)
return data.assets
}
/**
@@ -473,11 +469,8 @@ function createAssetService() {
}
const data = await res.json()
// Validate the single asset response against our schema
const result = assetResponseSchema.safeParse({ assets: [data] })
if (result.success && result.data.assets?.[0]) {
return result.data.assets[0]
}
const result = assetItemSchema.safeParse(data)
if (result.success) return result.data
const error = result.error
? fromZodError(result.error)
@@ -508,13 +501,12 @@ function createAssetService() {
`assets for tag ${tag}`
)
return (
data?.assets?.filter((asset) => !asset.tags.includes(MISSING_TAG)) ?? []
)
return data.assets
}
/**
* Gets every asset for a tag by walking paginated asset API responses.
* Pagination follows the required server-provided `has_more` flag.
*
* @param tag - The tag to filter by (e.g., 'models', 'input')
* @param includePublic - Whether to include public assets (default: true)
@@ -545,13 +537,14 @@ function createAssetService() {
},
`assets for tag ${tag}`
)
const batch = data.assets ?? []
assets.push(...batch.filter((asset) => !asset.tags.includes(MISSING_TAG)))
const batch = data.assets
if (batch.length === 0) {
return assets
}
const noMoreFromServer = data.has_more === false
const inferredLastPage =
data.has_more === undefined && batch.length < pageSize
if (batch.length === 0 || noMoreFromServer || inferredLastPage) {
assets.push(...batch)
if (!data.has_more) {
return assets
}