From 132a9dbb5ff1bc11c68daf9ef45869922e3ee8ba Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Tue, 8 Jul 2025 00:23:10 -0700 Subject: [PATCH] Add a feature flags message to reduce bandwidth --- src/config/clientFeatureFlags.json | 3 + src/schemas/apiSchema.ts | 3 + src/scripts/api.ts | 82 ++++++++++ tests-ui/tests/api.featureFlags.test.ts | 191 ++++++++++++++++++++++++ 4 files changed, 279 insertions(+) create mode 100644 src/config/clientFeatureFlags.json create mode 100644 tests-ui/tests/api.featureFlags.test.ts diff --git a/src/config/clientFeatureFlags.json b/src/config/clientFeatureFlags.json new file mode 100644 index 000000000..84a233ccf --- /dev/null +++ b/src/config/clientFeatureFlags.json @@ -0,0 +1,3 @@ +{ + "supports_preview_metadata": true +} diff --git a/src/schemas/apiSchema.ts b/src/schemas/apiSchema.ts index c5979c6bd..ab7b21a8c 100644 --- a/src/schemas/apiSchema.ts +++ b/src/schemas/apiSchema.ts @@ -129,6 +129,8 @@ const zLogRawResponse = z.object({ entries: z.array(zLogEntry) }) +const zFeatureFlagsWsMessage = z.record(z.string(), z.any()) + export type StatusWsMessageStatus = z.infer export type StatusWsMessage = z.infer export type ProgressWsMessage = z.infer @@ -150,6 +152,7 @@ export type DisplayComponentWsMessage = z.infer< > export type NodeProgressState = z.infer export type ProgressStateWsMessage = z.infer +export type FeatureFlagsWsMessage = z.infer // End of ws messages const zPromptInputItem = z.object({ diff --git a/src/scripts/api.ts b/src/scripts/api.ts index 84048c906..6015a21aa 100644 --- a/src/scripts/api.ts +++ b/src/scripts/api.ts @@ -1,5 +1,6 @@ import axios from 'axios' +import defaultClientFeatureFlags from '@/config/clientFeatureFlags.json' import type { DisplayComponentWsMessage, EmbeddingsResponse, @@ -11,6 +12,7 @@ import type { ExecutionStartWsMessage, ExecutionSuccessWsMessage, ExtensionsResponse, + FeatureFlagsWsMessage, HistoryTaskItem, LogsRawResponse, LogsWsMessage, @@ -116,6 +118,7 @@ interface BackendApiCalls { progress_text: ProgressTextWsMessage progress_state: ProgressStateWsMessage display_component: DisplayComponentWsMessage + feature_flags: FeatureFlagsWsMessage } /** Dictionary of all api calls */ @@ -245,6 +248,27 @@ export class ComfyApi extends EventTarget { reportedUnknownMessageTypes = new Set() + /** + * Feature flags supported by this frontend client. + */ + clientFeatureFlags: Record = { ...defaultClientFeatureFlags } + + /** + * Feature flags received from the backend server. + */ + serverFeatureFlags: Record = {} + + /** + * Alias for serverFeatureFlags for test compatibility. + */ + get feature_flags() { + return this.serverFeatureFlags + } + + set feature_flags(value: Record) { + this.serverFeatureFlags = value + } + /** * The auth token for the comfy org account if the user is logged in. * This is only used for {@link queuePrompt} now. It is not directly @@ -386,6 +410,15 @@ export class ComfyApi extends EventTarget { this.socket.addEventListener('open', () => { opened = true + + // Send feature flags as the first message + this.socket!.send( + JSON.stringify({ + type: 'feature_flags', + data: this.clientFeatureFlags + }) + ) + if (isReconnect) { this.dispatchCustomEvent('reconnected') } @@ -507,6 +540,14 @@ export class ComfyApi extends EventTarget { case 'b_preview': this.dispatchCustomEvent(msg.type, msg.data) break + case 'feature_flags': + // Store server feature flags + this.serverFeatureFlags = msg.data + console.log( + 'Server feature flags received:', + this.serverFeatureFlags + ) + break default: if (this.#registered.has(msg.type)) { // Fallback for custom types - calls super direct. @@ -995,6 +1036,47 @@ export class ComfyApi extends EventTarget { async getCustomNodesI18n(): Promise> { return (await axios.get(this.apiURL('/i18n'))).data } + + /** + * Checks if the server supports a specific feature. + * @param featureName The name of the feature to check + * @returns true if the feature is supported, false otherwise + */ + serverSupportsFeature(featureName: string): boolean { + return this.serverFeatureFlags[featureName] === true + } + + /** + * Gets a server feature flag value. + * @param featureName The name of the feature to get + * @param defaultValue The default value if the feature is not found + * @returns The feature value or default + */ + getServerFeature(featureName: string, defaultValue?: T): T { + return this.serverFeatureFlags[featureName] ?? defaultValue + } + + /** + * Gets all server feature flags. + * @returns Copy of all server feature flags + */ + getServerFeatures(): Record { + return { ...this.serverFeatureFlags } + } + + /** + * Updates the client feature flags. + * + * This is intentionally disabled for now. When we introduce an official Public API + * for the frontend, we'll introduce a function for custom frontend extensions to + * add their own feature flags in a way that won't interfere with other extensions + * or the builtin frontend flags. + */ + /* + setClientFeatureFlags(flags: Record): void { + this.clientFeatureFlags = flags + } + */ } export const api = new ComfyApi() diff --git a/tests-ui/tests/api.featureFlags.test.ts b/tests-ui/tests/api.featureFlags.test.ts new file mode 100644 index 000000000..4396fcef0 --- /dev/null +++ b/tests-ui/tests/api.featureFlags.test.ts @@ -0,0 +1,191 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import { api } from '@/scripts/api' + +describe('API Feature Flags', () => { + let mockWebSocket: any + const wsEventHandlers: { [key: string]: (event: any) => void } = {} + + beforeEach(() => { + // Use fake timers + vi.useFakeTimers() + + // Mock WebSocket + mockWebSocket = { + readyState: 1, // WebSocket.OPEN + send: vi.fn(), + close: vi.fn(), + addEventListener: vi.fn( + (event: string, handler: (event: any) => void) => { + wsEventHandlers[event] = handler + } + ), + removeEventListener: vi.fn() + } + + // Mock WebSocket constructor + global.WebSocket = vi.fn().mockImplementation(() => mockWebSocket) as any + + // Reset API state + api.feature_flags = {} + api.clientFeatureFlags = { + supports_preview_metadata: true, + api_version: '1.0.0', + capabilities: ['bulk_operations', 'async_nodes'] + } + }) + + afterEach(() => { + vi.useRealTimers() + vi.restoreAllMocks() + }) + + describe('Feature flags negotiation', () => { + it('should send client feature flags as first message on connection', async () => { + // Initialize API connection + const initPromise = api.init() + + // Simulate connection open + wsEventHandlers['open'](new Event('open')) + + // Check that feature flags were sent as first message + expect(mockWebSocket.send).toHaveBeenCalledTimes(1) + const sentMessage = JSON.parse(mockWebSocket.send.mock.calls[0][0]) + expect(sentMessage).toEqual({ + type: 'feature_flags', + data: { + supports_preview_metadata: true, + api_version: '1.0.0', + capabilities: ['bulk_operations', 'async_nodes'] + } + }) + + // Simulate server response with status message + wsEventHandlers['message']({ + data: JSON.stringify({ + type: 'status', + data: { + status: { exec_info: { queue_remaining: 0 } }, + sid: 'test-sid' + } + }) + }) + + // Simulate server feature flags response + wsEventHandlers['message']({ + data: JSON.stringify({ + type: 'feature_flags', + data: { + supports_preview_metadata: true, + async_execution: true, + supported_formats: ['webp', 'jpeg', 'png'], + api_version: '1.0.0', + max_upload_size: 104857600, + capabilities: ['isolated_nodes', 'dynamic_models'] + } + }) + }) + + await initPromise + + // Check that server features were stored + expect(api.feature_flags).toEqual({ + supports_preview_metadata: true, + async_execution: true, + supported_formats: ['webp', 'jpeg', 'png'], + api_version: '1.0.0', + max_upload_size: 104857600, + capabilities: ['isolated_nodes', 'dynamic_models'] + }) + }) + + it('should handle server without feature flags support', async () => { + // Initialize API connection + const initPromise = api.init() + + // Simulate connection open + wsEventHandlers['open'](new Event('open')) + + // Clear the send mock to reset + mockWebSocket.send.mockClear() + + // Simulate server response with status but no feature flags + wsEventHandlers['message']({ + data: JSON.stringify({ + type: 'status', + data: { + status: { exec_info: { queue_remaining: 0 } }, + sid: 'test-sid' + } + }) + }) + + // Simulate some other message (not feature flags) + wsEventHandlers['message']({ + data: JSON.stringify({ + type: 'execution_start', + data: {} + }) + }) + + await initPromise + + // Server features should remain empty + expect(api.feature_flags).toEqual({}) + }) + }) + + describe('Feature checking methods', () => { + beforeEach(() => { + // Set up some test features + api.feature_flags = { + supports_preview_metadata: true, + async_execution: false, + capabilities: ['isolated_nodes', 'dynamic_models'] + } + }) + + it('should check if server supports a boolean feature', () => { + expect(api.serverSupportsFeature('supports_preview_metadata')).toBe(true) + expect(api.serverSupportsFeature('async_execution')).toBe(false) + expect(api.serverSupportsFeature('non_existent_feature')).toBe(false) + }) + + it('should get server feature value', () => { + expect(api.getServerFeature('supports_preview_metadata')).toBe(true) + expect(api.getServerFeature('capabilities')).toEqual([ + 'isolated_nodes', + 'dynamic_models' + ]) + expect(api.getServerFeature('non_existent_feature')).toBeUndefined() + }) + }) + + describe('Client feature flags configuration', () => { + it('should use default client feature flags', () => { + // Verify default flags are loaded from config + expect(api.clientFeatureFlags).toHaveProperty( + 'supports_preview_metadata', + true + ) + expect(api.clientFeatureFlags).toHaveProperty('api_version', '1.0.0') + expect(api.clientFeatureFlags).toHaveProperty('capabilities') + expect(api.clientFeatureFlags.capabilities).toEqual([ + 'bulk_operations', + 'async_nodes' + ]) + }) + }) + + describe('Integration with preview messages', () => { + it('should affect preview message handling based on feature support', () => { + // Test with metadata support + api.feature_flags = { supports_preview_metadata: true } + expect(api.serverSupportsFeature('supports_preview_metadata')).toBe(true) + + // Test without metadata support + api.feature_flags = {} + expect(api.serverSupportsFeature('supports_preview_metadata')).toBe(false) + }) + }) +})