diff --git a/src/types/comfyWorkflow.ts b/src/types/comfyWorkflow.ts index f68a5b31ed..370918d911 100644 --- a/src/types/comfyWorkflow.ts +++ b/src/types/comfyWorkflow.ts @@ -1,4 +1,4 @@ -import { z } from 'zod' +import { z, type SafeParseReturnType } from 'zod' import { fromZodError } from 'zod-validation-error' // GroupNode is hacking node id to be a string, so we need to allow that. @@ -21,6 +21,14 @@ export const zSlotIndex = z.union([ // - https://github.com/rgthree/rgthree-comfy Context Big node is using array as type. export const zDataType = z.union([z.string(), z.array(z.string()), z.number()]) +const zVector2 = z.union([ + z + .object({ 0: z.number(), 1: z.number() }) + .passthrough() + .transform((v) => [v[0], v[1]]), + z.tuple([z.number(), z.number()]) +]) + // Definition of an AI model file used in the workflow. const zModelFile = z.object({ name: z.string(), @@ -30,6 +38,13 @@ const zModelFile = z.object({ directory: z.string() }) +const zGraphState = z.object({ + lastGroupid: z.number().optional(), + lastNodeId: z.number().optional(), + lastLinkId: z.number().optional(), + lastRerouteId: z.number().optional() +}) + const zComfyLink = z.tuple([ z.number(), // Link id zNodeId, // Node id of source node @@ -39,6 +54,23 @@ const zComfyLink = z.tuple([ zDataType // Data type ]) +const zComfyLinkObject = z.object({ + id: z.number(), + origin_id: zNodeId, + origin_slot: zSlotIndex, + target_id: zNodeId, + target_slot: zSlotIndex, + type: zDataType, + parentId: z.number().optional() +}) + +const zReroute = z.object({ + id: z.number(), + parentId: z.number().optional(), + pos: zVector2, + linkIds: z.array(z.number()).nullish() +}) + const zNodeOutput = z .object({ name: z.string(), @@ -73,14 +105,6 @@ const zProperties = z }) .passthrough() -const zVector2 = z.union([ - z - .object({ 0: z.number(), 1: z.number() }) - .passthrough() - .transform((v) => [v[0], v[1]]), - z.tuple([z.number(), z.number()]) -]) - const zWidgetValues = z.union([z.array(z.any()), z.record(z.any())]) const zComfyNode = z @@ -158,21 +182,54 @@ export const zComfyWorkflow = z }) .passthrough() +const zComfyWorkflow1 = z + .object({ + version: z.number(), + config: zConfig.optional().nullable(), + state: zGraphState, + groups: z.array(zGroup).optional(), + nodes: z.array(zComfyNode), + links: z.array(zComfyLinkObject).optional(), + reroutes: z.array(zReroute).optional(), + extra: zExtra.optional().nullable(), + models: z.array(zModelFile).optional() + }) + .passthrough() + export type NodeInput = z.infer export type NodeOutput = z.infer export type ComfyLink = z.infer export type ComfyNode = z.infer -export type ComfyWorkflowJSON = z.infer +export type ComfyWorkflowJSON = z.infer< + typeof zComfyWorkflow | typeof zComfyWorkflow1 +> + +const zWorkflowVersion = z.object({ + version: z.number() +}) export async function validateComfyWorkflow( - data: any, + data: unknown, onError: (error: string) => void = console.warn ): Promise { - const result = await zComfyWorkflow.safeParseAsync(data) - if (!result.success) { - const error = fromZodError(result.error) - onError(`Invalid workflow against zod schema:\n${error}`) + const versionResult = zWorkflowVersion.safeParse(data) + + let result: SafeParseReturnType + if (!versionResult.success) { + // Invalid workflow + const error = fromZodError(versionResult.error) + onError(`Workflow does not contain a valid version. Zod error:\n${error}`) return null + } else if (versionResult.data.version === 1) { + // Schema version 1 + result = await zComfyWorkflow1.safeParseAsync(data) + } else { + // Unknown or old version: 0.4 + result = await zComfyWorkflow.safeParseAsync(data) } - return result.data + if (result.success) return result.data + + const error = fromZodError(result.error) + onError(`Invalid workflow against zod schema:\n${error}`) + return null } diff --git a/tests-ui/tests/fast/comfyWorkflow.test.ts b/tests-ui/tests/fast/comfyWorkflow.test.ts index 4b86663478..8e27f3ff6a 100644 --- a/tests-ui/tests/fast/comfyWorkflow.test.ts +++ b/tests-ui/tests/fast/comfyWorkflow.test.ts @@ -32,10 +32,11 @@ describe('parseComfyWorkflow', () => { workflow.version = undefined expect(await validateComfyWorkflow(workflow)).toBeNull() - workflow.version = '1.0.1' // Invalid format. + workflow.version = '1.0.1' // Invalid format (string) expect(await validateComfyWorkflow(workflow)).toBeNull() - workflow.version = 1 + // 2018-2024 schema: 0.4 + workflow.version = 0.4 expect(await validateComfyWorkflow(workflow)).not.toBeNull() })