diff --git a/src/platform/nodeReplacement/nodeReplacementStore.test.ts b/src/platform/nodeReplacement/nodeReplacementStore.test.ts index a8550ecb7..4523cd8e9 100644 --- a/src/platform/nodeReplacement/nodeReplacementStore.test.ts +++ b/src/platform/nodeReplacement/nodeReplacementStore.test.ts @@ -3,7 +3,9 @@ import type { NodeReplacementResponse } from './types' import { createPinia, setActivePinia } from 'pinia' import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ServerFeatureFlag } from '@/composables/useFeatureFlags' import { useSettingStore } from '@/platform/settings/settingStore' +import { api } from '@/scripts/api' import { fetchNodeReplacements } from './nodeReplacementService' import { useNodeReplacementStore } from './nodeReplacementStore' @@ -15,16 +17,10 @@ vi.mock('./nodeReplacementService', () => ({ fetchNodeReplacements: vi.fn() })) -const mockNodeReplacementsEnabled = vi.hoisted(() => ({ value: true })) - -vi.mock('@/composables/useFeatureFlags', () => ({ - useFeatureFlags: vi.fn(() => ({ - flags: { - get nodeReplacementsEnabled() { - return mockNodeReplacementsEnabled.value - } - } - })) +vi.mock('@/scripts/api', () => ({ + api: { + getServerFeature: vi.fn() + } })) function mockSettingStore(enabled: boolean) { @@ -39,10 +35,17 @@ function mockSettingStore(enabled: boolean) { }) } -function createStore(enabled = true, featureEnabled = true) { +function createStore(settingEnabled = true, serverFeatureEnabled = true) { setActivePinia(createPinia()) - mockSettingStore(enabled) - mockNodeReplacementsEnabled.value = featureEnabled + mockSettingStore(settingEnabled) + vi.mocked(api.getServerFeature).mockImplementation( + (flag: string, defaultValue?: unknown) => { + if (flag === ServerFeatureFlag.NODE_REPLACEMENTS) { + return serverFeatureEnabled + } + return defaultValue + } + ) return useNodeReplacementStore() } @@ -51,8 +54,7 @@ describe('useNodeReplacementStore', () => { beforeEach(() => { vi.clearAllMocks() - mockNodeReplacementsEnabled.value = true - store = createStore(true) + store = createStore() }) it('should initialize with empty replacements', () => { @@ -242,7 +244,7 @@ describe('useNodeReplacementStore', () => { consoleErrorSpy.mockRestore() }) - it('should not fetch when feature is disabled', async () => { + it('should not fetch when setting is disabled', async () => { vi.mocked(fetchNodeReplacements).mockResolvedValue({}) store = createStore(false) @@ -252,6 +254,16 @@ describe('useNodeReplacementStore', () => { expect(store.isLoaded).toBe(false) }) + it('should not fetch when server feature flag is disabled', async () => { + vi.mocked(fetchNodeReplacements).mockResolvedValue(mockReplacements) + store = createStore(true, false) + + await store.load() + + expect(fetchNodeReplacements).not.toHaveBeenCalled() + expect(store.isLoaded).toBe(false) + }) + it('should not re-fetch when called twice', async () => { vi.mocked(fetchNodeReplacements).mockResolvedValue(mockReplacements) store = createStore() @@ -261,25 +273,5 @@ describe('useNodeReplacementStore', () => { expect(fetchNodeReplacements).toHaveBeenCalledOnce() }) - - it('should not call API when setting is disabled', async () => { - vi.mocked(fetchNodeReplacements).mockResolvedValue(mockReplacements) - store = createStore(false) - - await store.load() - - expect(fetchNodeReplacements).not.toHaveBeenCalled() - expect(store.isLoaded).toBe(false) - }) - - it('should not call API when server feature flag is disabled', async () => { - vi.mocked(fetchNodeReplacements).mockResolvedValue(mockReplacements) - store = createStore(true, false) - - await store.load() - - expect(fetchNodeReplacements).not.toHaveBeenCalled() - expect(store.isLoaded).toBe(false) - }) }) }) diff --git a/src/platform/nodeReplacement/nodeReplacementStore.ts b/src/platform/nodeReplacement/nodeReplacementStore.ts index 713bb687f..bf11ba7cb 100644 --- a/src/platform/nodeReplacement/nodeReplacementStore.ts +++ b/src/platform/nodeReplacement/nodeReplacementStore.ts @@ -3,8 +3,9 @@ import type { NodeReplacement, NodeReplacementResponse } from './types' import { defineStore } from 'pinia' import { computed, ref } from 'vue' -import { useFeatureFlags } from '@/composables/useFeatureFlags' +import { ServerFeatureFlag } from '@/composables/useFeatureFlags' import { useSettingStore } from '@/platform/settings/settingStore' +import { api } from '@/scripts/api' import { fetchNodeReplacements } from './nodeReplacementService' export const useNodeReplacementStore = defineStore('nodeReplacement', () => { @@ -15,11 +16,10 @@ export const useNodeReplacementStore = defineStore('nodeReplacement', () => { settingStore.get('Comfy.NodeReplacement.Enabled') ) - const { flags } = useFeatureFlags() - async function load() { if (!isEnabled.value || isLoaded.value) return - if (!flags.nodeReplacementsEnabled) return + if (!api.getServerFeature(ServerFeatureFlag.NODE_REPLACEMENTS, false)) + return try { replacements.value = await fetchNodeReplacements() @@ -42,8 +42,8 @@ export const useNodeReplacementStore = defineStore('nodeReplacement', () => { return { replacements, isLoaded, - load, isEnabled, + load, getReplacementFor, hasReplacement } diff --git a/src/scripts/api.ts b/src/scripts/api.ts index 67357ea6d..c62a19d59 100644 --- a/src/scripts/api.ts +++ b/src/scripts/api.ts @@ -700,6 +700,7 @@ export class ComfyApi extends EventTarget { 'Server feature flags received:', this.serverFeatureFlags ) + this.dispatchCustomEvent('feature_flags', msg.data) break default: if (this._registered.has(msg.type)) { diff --git a/src/scripts/app.ts b/src/scripts/app.ts index 052d8cbd9..9aea89988 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -739,6 +739,10 @@ export class ComfyApp { releaseSharedObjectUrl(blobUrl) }) + api.addEventListener('feature_flags', () => { + void useNodeReplacementStore().load() + }) + api.init() } @@ -802,7 +806,6 @@ export class ComfyApp { await useWorkspaceStore().workflow.syncWorkflows() //Doesn't need to block. Blueprints will load async void useSubgraphStore().fetchSubgraphs() - await useNodeReplacementStore().load() await useExtensionService().loadExtensions() this.addProcessKeyHandler()