diff --git a/package.json b/package.json index ffadc2c61..c95c2bb8a 100644 --- a/package.json +++ b/package.json @@ -6,7 +6,7 @@ "scripts": { "dev": "vite", "build": "npm run typecheck && vite build", - "deploy": "node scripts/deploy.js", + "deploy": "npm run build && node scripts/deploy.js", "zipdist": "node scripts/zipdist.js", "typecheck": "tsc --noEmit", "format": "prettier --write 'src/**/*.{js,ts,tsx,vue}'", diff --git a/src/scripts/api.ts b/src/scripts/api.ts index 1001c90c6..3f8bc3d14 100644 --- a/src/scripts/api.ts +++ b/src/scripts/api.ts @@ -4,6 +4,7 @@ import { PendingTaskItem, RunningTaskItem, ComfyNodeDef, + validateComfyNodeDef, } from "@/types/apiTypes"; interface QueuePromptRequestBody { @@ -240,7 +241,17 @@ class ComfyApi extends EventTarget { */ async getNodeDefs(): Promise> { const resp = await this.fetchApi("/object_info", { cache: "no-store" }); - return await resp.json(); + const objectInfoUnsafe = await resp.json(); + const objectInfo: Record = {}; + for (const key in objectInfoUnsafe) { + try { + objectInfo[key] = validateComfyNodeDef(objectInfoUnsafe[key]); + } catch (e) { + console.warn("Ignore node definition: ", key); + console.error(e); + } + } + return objectInfo; } /** diff --git a/src/types/apiTypes.ts b/src/types/apiTypes.ts index 4c3b167da..98c3bec37 100644 --- a/src/types/apiTypes.ts +++ b/src/types/apiTypes.ts @@ -1,5 +1,6 @@ import { ZodType, z } from "zod"; import { zComfyWorkflow } from "./comfyWorkflow"; +import { fromZodError } from "zod-validation-error"; const zNodeId = z.number(); const zNodeType = z.string(); @@ -124,9 +125,21 @@ export type TaskItem = z.infer; // TODO: validate `/history` `/queue` API endpoint responses. -function inputSpec(spec: [ZodType, ZodType]): ZodType { +function inputSpec( + spec: [ZodType, ZodType], + allowUpcast: boolean = true +): ZodType { const [inputType, inputSpec] = spec; - return z.union([z.tuple([inputType, inputSpec]), z.tuple([inputType])]); + // e.g. "INT" => ["INT", {}] + const upcastTypes: ZodType[] = allowUpcast + ? [inputType.transform((type) => [type, {}])] + : []; + + return z.union([ + z.tuple([inputType, inputSpec]), + z.tuple([inputType]).transform(([type]) => [type, {}]), + ...upcastTypes, + ]); } const zIntInputSpec = inputSpec([ @@ -173,15 +186,18 @@ const zStringInputSpec = inputSpec([ ]); // Dropdown Selection. -const zComboInputSpec = inputSpec([ - z.array(z.any()), - z.object({ - default: z.any().optional(), - control_after_generate: z.boolean().optional(), - image_upload: z.boolean().optional(), - forceInput: z.boolean().optional(), - }), -]); +const zComboInputSpec = inputSpec( + [ + z.array(z.any()), + z.object({ + default: z.any().optional(), + control_after_generate: z.boolean().optional(), + image_upload: z.boolean().optional(), + forceInput: z.boolean().optional(), + }), + ], + /* allowUpcast=*/ false +); const zCustomInputSpec = inputSpec([ z.string(), @@ -210,6 +226,9 @@ const zComfyNodeDef = z.object({ input: z.object({ required: z.record(zInputSpec).optional(), optional: z.record(zInputSpec).optional(), + // Frontend repo is not using it, but some custom nodes are using the + // hidden field to pass various values. + hidden: z.record(z.any()).optional(), }), output: zComfyOutputSpec, output_is_list: z.array(z.boolean()), @@ -227,4 +246,15 @@ export type ComfyInputSpec = z.infer; export type ComfyOutputSpec = z.infer; export type ComfyNodeDef = z.infer; -// TODO: validate `/object_info` API endpoint responses. +export function validateComfyNodeDef(data: any): ComfyNodeDef { + const result = zComfyNodeDef.safeParse(data); + if (!result.success) { + const zodError = fromZodError(result.error); + const error = new Error( + `Invalid ComfyNodeDef: ${JSON.stringify(data)}\n${zodError.message}` + ); + error.cause = zodError; + throw error; + } + return result.data; +} diff --git a/tests-ui/tests/apiTypes.test.ts b/tests-ui/tests/apiTypes.test.ts new file mode 100644 index 000000000..335793713 --- /dev/null +++ b/tests-ui/tests/apiTypes.test.ts @@ -0,0 +1,77 @@ +import { ComfyNodeDef, validateComfyNodeDef } from "@/types/apiTypes"; +const fs = require("fs"); +const path = require("path"); + +const EXAMPLE_NODE_DEF: ComfyNodeDef = { + input: { + required: { + ckpt_name: [["model1.safetensors", "model2.ckpt"]], + }, + }, + output: ["MODEL", "CLIP", "VAE"], + output_is_list: [false, false, false], + output_name: ["MODEL", "CLIP", "VAE"], + name: "CheckpointLoaderSimple", + display_name: "Load Checkpoint", + description: "", + python_module: "nodes", + category: "loaders", + output_node: false, +}; + +describe("validateNodeDef", () => { + it("Should accept a valid node definition", () => { + expect(() => validateComfyNodeDef(EXAMPLE_NODE_DEF)).not.toThrow(); + }); + + describe.each([ + [{ ckpt_name: "foo" }, ["foo", {}]], + [{ ckpt_name: ["foo"] }, ["foo", {}]], + [{ ckpt_name: ["foo", { default: 1 }] }, ["foo", { default: 1 }]], + ])( + "validateComfyNodeDef with various input spec formats", + (inputSpec, expected) => { + it(`should accept input spec format: ${JSON.stringify(inputSpec)}`, () => { + expect( + validateComfyNodeDef({ + ...EXAMPLE_NODE_DEF, + input: { + required: inputSpec, + }, + }).input.required.ckpt_name + ).toEqual(expected); + }); + } + ); + + describe.each([ + [{ ckpt_name: { "model1.safetensors": "foo" } }], + [{ ckpt_name: ["*", ""] }], + [{ ckpt_name: ["foo", { default: 1 }, { default: 2 }] }], + ])( + "validateComfyNodeDef rejects with various input spec formats", + (inputSpec) => { + it(`should accept input spec format: ${JSON.stringify(inputSpec)}`, () => { + expect(() => + validateComfyNodeDef({ + ...EXAMPLE_NODE_DEF, + input: { + required: inputSpec, + }, + }) + ).toThrow(); + }); + } + ); + + it("Should accept all built-in node definitions", async () => { + const nodeDefs = Object.values( + JSON.parse( + fs.readFileSync(path.resolve("./tests-ui/data/object_info.json")) + ) + ); + nodeDefs.forEach((nodeDef) => { + expect(() => validateComfyNodeDef(nodeDef)).not.toThrow(); + }); + }); +});