diff --git a/src/platform/navigation/preservedQueryManager.ts b/src/platform/navigation/preservedQueryManager.ts new file mode 100644 index 000000000..358d40144 --- /dev/null +++ b/src/platform/navigation/preservedQueryManager.ts @@ -0,0 +1,109 @@ +import type { LocationQuery, LocationQueryRaw } from 'vue-router' + +const STORAGE_PREFIX = 'Comfy.PreservedQuery.' +const preservedQueries = new Map>() + +const readQueryParam = (value: unknown): string | undefined => { + return typeof value === 'string' ? value : undefined +} + +const getStorageKey = (namespace: string) => `${STORAGE_PREFIX}${namespace}` + +const isValidQueryRecord = ( + value: unknown +): value is Record => { + if (typeof value !== 'object' || value === null || Array.isArray(value)) { + return false + } + return Object.values(value).every((v) => typeof v === 'string') +} + +const readFromStorage = (namespace: string): Record | null => { + try { + const raw = sessionStorage.getItem(getStorageKey(namespace)) + if (!raw) return null + + const parsed = JSON.parse(raw) + if (!isValidQueryRecord(parsed)) { + console.warn('[preservedQuery] invalid storage format') + sessionStorage.removeItem(getStorageKey(namespace)) + return null + } + return parsed + } catch (error) { + console.warn('[preservedQuery] storage operation failed') + sessionStorage.removeItem(getStorageKey(namespace)) + return null + } +} + +const writeToStorage = ( + namespace: string, + payload: Record | null +) => { + try { + if (!payload || Object.keys(payload).length === 0) { + sessionStorage.removeItem(getStorageKey(namespace)) + return + } + sessionStorage.setItem(getStorageKey(namespace), JSON.stringify(payload)) + } catch (error) { + console.warn('[preservedQuery] failed to write storage', { + namespace, + error + }) + } +} + +export const hydratePreservedQuery = (namespace: string) => { + if (preservedQueries.has(namespace)) return + const payload = readFromStorage(namespace) + if (payload) { + preservedQueries.set(namespace, payload) + } +} + +export const capturePreservedQuery = ( + namespace: string, + query: LocationQuery, + keys: string[] +) => { + const payload: Record = {} + + keys.forEach((key) => { + const value = readQueryParam(query[key]) + if (value) { + payload[key] = value + } + }) + + if (Object.keys(payload).length === 0) return + + preservedQueries.set(namespace, payload) + writeToStorage(namespace, payload) +} + +export const mergePreservedQueryIntoQuery = ( + namespace: string, + query?: LocationQueryRaw +): LocationQueryRaw | undefined => { + const payload = preservedQueries.get(namespace) + if (!payload) return undefined + + const nextQuery: LocationQueryRaw = { ...(query || {}) } + let changed = false + + for (const [key, value] of Object.entries(payload)) { + if (typeof nextQuery[key] === 'string') continue + nextQuery[key] = value + changed = true + } + + return changed ? nextQuery : undefined +} + +export const clearPreservedQuery = (namespace: string) => { + if (!preservedQueries.has(namespace)) return + preservedQueries.delete(namespace) + writeToStorage(namespace, null) +} diff --git a/src/platform/navigation/preservedQueryNamespaces.ts b/src/platform/navigation/preservedQueryNamespaces.ts new file mode 100644 index 000000000..541f18869 --- /dev/null +++ b/src/platform/navigation/preservedQueryNamespaces.ts @@ -0,0 +1,3 @@ +export const PRESERVED_QUERY_NAMESPACES = { + TEMPLATE: 'template' +} as const diff --git a/src/platform/navigation/preservedQueryTracker.ts b/src/platform/navigation/preservedQueryTracker.ts new file mode 100644 index 000000000..21906194d --- /dev/null +++ b/src/platform/navigation/preservedQueryTracker.ts @@ -0,0 +1,29 @@ +import type { Router } from 'vue-router' + +import { + capturePreservedQuery, + hydratePreservedQuery +} from '@/platform/navigation/preservedQueryManager' + +export const installPreservedQueryTracker = ( + router: Router, + definitions: Array<{ namespace: string; keys: string[] }> +) => { + const trackedDefinitions = definitions.map((definition) => ({ + ...definition + })) + + router.beforeEach((to, _from, next) => { + const queryKeys = new Set(Object.keys(to.query)) + + trackedDefinitions.forEach(({ namespace, keys }) => { + hydratePreservedQuery(namespace) + const shouldCapture = keys.some((key) => queryKeys.has(key)) + if (shouldCapture) { + capturePreservedQuery(namespace, to.query, keys) + } + }) + + next() + }) +} diff --git a/src/platform/workflow/persistence/composables/useWorkflowPersistence.ts b/src/platform/workflow/persistence/composables/useWorkflowPersistence.ts index d628e7a1b..a297b2a79 100644 --- a/src/platform/workflow/persistence/composables/useWorkflowPersistence.ts +++ b/src/platform/workflow/persistence/composables/useWorkflowPersistence.ts @@ -1,7 +1,12 @@ import { tryOnScopeDispose } from '@vueuse/core' import { computed, watch } from 'vue' -import { useRoute } from 'vue-router' +import { useRoute, useRouter } from 'vue-router' +import { + hydratePreservedQuery, + mergePreservedQueryIntoQuery +} from '@/platform/navigation/preservedQueryManager' +import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces' import { useSettingStore } from '@/platform/settings/settingStore' import { useWorkflowService } from '@/platform/workflow/core/services/workflowService' import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore' @@ -15,7 +20,23 @@ export function useWorkflowPersistence() { const workflowStore = useWorkflowStore() const settingStore = useSettingStore() const route = useRoute() + const router = useRouter() const templateUrlLoader = useTemplateUrlLoader() + const TEMPLATE_NAMESPACE = PRESERVED_QUERY_NAMESPACES.TEMPLATE + + const ensureTemplateQueryFromIntent = async () => { + hydratePreservedQuery(TEMPLATE_NAMESPACE) + const mergedQuery = mergePreservedQueryIntoQuery( + TEMPLATE_NAMESPACE, + route.query + ) + + if (mergedQuery) { + await router.replace({ query: mergedQuery }) + } + + return mergedQuery ?? route.query + } const workflowPersistenceEnabled = computed(() => settingStore.get('Comfy.Workflow.Persist') @@ -101,8 +122,8 @@ export function useWorkflowPersistence() { } const loadTemplateFromUrlIfPresent = async () => { - const hasTemplateUrl = - route.query.template && typeof route.query.template === 'string' + const query = await ensureTemplateQueryFromIntent() + const hasTemplateUrl = query.template && typeof query.template === 'string' if (hasTemplateUrl) { await templateUrlLoader.loadTemplateFromUrl() diff --git a/src/platform/workflow/templates/composables/useTemplateUrlLoader.ts b/src/platform/workflow/templates/composables/useTemplateUrlLoader.ts index cccadfe07..ea5f97302 100644 --- a/src/platform/workflow/templates/composables/useTemplateUrlLoader.ts +++ b/src/platform/workflow/templates/composables/useTemplateUrlLoader.ts @@ -2,6 +2,9 @@ import { useToast } from 'primevue/usetoast' import { useI18n } from 'vue-i18n' import { useRoute, useRouter } from 'vue-router' +import { clearPreservedQuery } from '@/platform/navigation/preservedQueryManager' +import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces' + import { useTemplateWorkflows } from './useTemplateWorkflows' /** @@ -21,6 +24,7 @@ export function useTemplateUrlLoader() { const { t } = useI18n() const toast = useToast() const templateWorkflows = useTemplateWorkflows() + const TEMPLATE_NAMESPACE = PRESERVED_QUERY_NAMESPACES.TEMPLATE /** * Validates parameter format to prevent path traversal and injection attacks @@ -97,6 +101,7 @@ export function useTemplateUrlLoader() { }) } finally { cleanupUrlParams() + clearPreservedQuery(TEMPLATE_NAMESPACE) } } diff --git a/src/router.ts b/src/router.ts index 9959975d6..bb9addcfd 100644 --- a/src/router.ts +++ b/src/router.ts @@ -14,6 +14,8 @@ import { useUserStore } from '@/stores/userStore' import { isElectron } from '@/utils/envUtil' import LayoutDefault from '@/views/layouts/LayoutDefault.vue' +import { installPreservedQueryTracker } from '@/platform/navigation/preservedQueryTracker' +import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces' import { cloudOnboardingRoutes } from './platform/cloud/onboarding/onboardingCloudRoutes' const isFileProtocol = window.location.protocol === 'file:' @@ -75,6 +77,13 @@ const router = createRouter({ } }) +installPreservedQueryTracker(router, [ + { + namespace: PRESERVED_QUERY_NAMESPACES.TEMPLATE, + keys: ['template', 'source'] + } +]) + if (isCloud) { const PUBLIC_ROUTE_NAMES = new Set([ 'cloud-login', diff --git a/tests-ui/tests/platform/navigation/preservedQueryManager.test.ts b/tests-ui/tests/platform/navigation/preservedQueryManager.test.ts new file mode 100644 index 000000000..5807005b0 --- /dev/null +++ b/tests-ui/tests/platform/navigation/preservedQueryManager.test.ts @@ -0,0 +1,73 @@ +import { beforeEach, describe, expect, it } from 'vitest' + +import { + capturePreservedQuery, + clearPreservedQuery, + hydratePreservedQuery, + mergePreservedQueryIntoQuery +} from '@/platform/navigation/preservedQueryManager' +import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces' + +const NAMESPACE = PRESERVED_QUERY_NAMESPACES.TEMPLATE + +describe('preservedQueryManager', () => { + beforeEach(() => { + sessionStorage.clear() + clearPreservedQuery(NAMESPACE) + }) + + it('captures specified keys from the route query', () => { + capturePreservedQuery(NAMESPACE, { template: 'flux', source: 'custom' }, [ + 'template', + 'source' + ]) + + hydratePreservedQuery(NAMESPACE) + const merged = mergePreservedQueryIntoQuery(NAMESPACE) + + expect(merged).toEqual({ template: 'flux', source: 'custom' }) + expect(sessionStorage.getItem('Comfy.PreservedQuery.template')).toBeTruthy() + }) + + it('hydrates cached payload from sessionStorage once', () => { + sessionStorage.setItem( + 'Comfy.PreservedQuery.template', + JSON.stringify({ template: 'flux', source: 'default' }) + ) + + hydratePreservedQuery(NAMESPACE) + const merged = mergePreservedQueryIntoQuery(NAMESPACE) + + expect(merged).toEqual({ template: 'flux', source: 'default' }) + }) + + it('merges stored payload only when query lacks the keys', () => { + capturePreservedQuery(NAMESPACE, { template: 'flux' }, ['template']) + + const merged = mergePreservedQueryIntoQuery(NAMESPACE, { + foo: 'bar' + }) + + expect(merged).toEqual({ foo: 'bar', template: 'flux' }) + }) + + it('returns undefined when merge does not change query', () => { + capturePreservedQuery(NAMESPACE, { template: 'flux' }, ['template']) + + const merged = mergePreservedQueryIntoQuery(NAMESPACE, { + template: 'existing' + }) + + expect(merged).toBeUndefined() + }) + + it('clears cached payload', () => { + capturePreservedQuery(NAMESPACE, { template: 'flux' }, ['template']) + + clearPreservedQuery(NAMESPACE) + + const merged = mergePreservedQueryIntoQuery(NAMESPACE) + expect(merged).toBeUndefined() + expect(sessionStorage.getItem('Comfy.PreservedQuery.template')).toBeNull() + }) +}) diff --git a/tests-ui/tests/platform/workflow/templates/composables/useTemplateUrlLoader.test.ts b/tests-ui/tests/platform/workflow/templates/composables/useTemplateUrlLoader.test.ts index 10127a47f..5036cd730 100644 --- a/tests-ui/tests/platform/workflow/templates/composables/useTemplateUrlLoader.test.ts +++ b/tests-ui/tests/platform/workflow/templates/composables/useTemplateUrlLoader.test.ts @@ -12,6 +12,10 @@ import { useTemplateUrlLoader } from '@/platform/workflow/templates/composables/ * - Input validation for template and source parameters */ +const preservedQueryMocks = vi.hoisted(() => ({ + clearPreservedQuery: vi.fn() +})) + // Mock vue-router let mockQueryParams: Record = {} const mockRouterReplace = vi.fn() @@ -25,6 +29,11 @@ vi.mock('vue-router', () => ({ })) })) +vi.mock( + '@/platform/navigation/preservedQueryManager', + () => preservedQueryMocks +) + // Mock template workflows composable const mockLoadTemplates = vi.fn().mockResolvedValue(true) const mockLoadWorkflowTemplate = vi.fn().mockResolvedValue(true) @@ -88,6 +97,7 @@ describe('useTemplateUrlLoader', () => { 'flux_simple', 'default' ) + expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledTimes(1) }) it('uses default source when source param is not provided', async () => {