Validate node def from /object_info endpoint (#159)

* Validate node def

* nit

* nit

* More tests
This commit is contained in:
Chenlei Hu
2024-07-18 12:20:47 -04:00
committed by GitHub
parent 2568746071
commit 9961be1bc7
4 changed files with 132 additions and 14 deletions

View File

@@ -6,7 +6,7 @@
"scripts": {
"dev": "vite",
"build": "npm run typecheck && vite build",
"deploy": "node scripts/deploy.js",
"deploy": "npm run build && node scripts/deploy.js",
"zipdist": "node scripts/zipdist.js",
"typecheck": "tsc --noEmit",
"format": "prettier --write 'src/**/*.{js,ts,tsx,vue}'",

View File

@@ -4,6 +4,7 @@ import {
PendingTaskItem,
RunningTaskItem,
ComfyNodeDef,
validateComfyNodeDef,
} from "@/types/apiTypes";
interface QueuePromptRequestBody {
@@ -240,7 +241,17 @@ class ComfyApi extends EventTarget {
*/
async getNodeDefs(): Promise<Record<string, ComfyNodeDef>> {
const resp = await this.fetchApi("/object_info", { cache: "no-store" });
return await resp.json();
const objectInfoUnsafe = await resp.json();
const objectInfo: Record<string, ComfyNodeDef> = {};
for (const key in objectInfoUnsafe) {
try {
objectInfo[key] = validateComfyNodeDef(objectInfoUnsafe[key]);
} catch (e) {
console.warn("Ignore node definition: ", key);
console.error(e);
}
}
return objectInfo;
}
/**

View File

@@ -1,5 +1,6 @@
import { ZodType, z } from "zod";
import { zComfyWorkflow } from "./comfyWorkflow";
import { fromZodError } from "zod-validation-error";
const zNodeId = z.number();
const zNodeType = z.string();
@@ -124,9 +125,21 @@ export type TaskItem = z.infer<typeof zTaskItem>;
// TODO: validate `/history` `/queue` API endpoint responses.
function inputSpec(spec: [ZodType, ZodType]): ZodType {
function inputSpec(
spec: [ZodType, ZodType],
allowUpcast: boolean = true
): ZodType {
const [inputType, inputSpec] = spec;
return z.union([z.tuple([inputType, inputSpec]), z.tuple([inputType])]);
// e.g. "INT" => ["INT", {}]
const upcastTypes: ZodType[] = allowUpcast
? [inputType.transform((type) => [type, {}])]
: [];
return z.union([
z.tuple([inputType, inputSpec]),
z.tuple([inputType]).transform(([type]) => [type, {}]),
...upcastTypes,
]);
}
const zIntInputSpec = inputSpec([
@@ -173,15 +186,18 @@ const zStringInputSpec = inputSpec([
]);
// Dropdown Selection.
const zComboInputSpec = inputSpec([
z.array(z.any()),
z.object({
default: z.any().optional(),
control_after_generate: z.boolean().optional(),
image_upload: z.boolean().optional(),
forceInput: z.boolean().optional(),
}),
]);
const zComboInputSpec = inputSpec(
[
z.array(z.any()),
z.object({
default: z.any().optional(),
control_after_generate: z.boolean().optional(),
image_upload: z.boolean().optional(),
forceInput: z.boolean().optional(),
}),
],
/* allowUpcast=*/ false
);
const zCustomInputSpec = inputSpec([
z.string(),
@@ -210,6 +226,9 @@ const zComfyNodeDef = z.object({
input: z.object({
required: z.record(zInputSpec).optional(),
optional: z.record(zInputSpec).optional(),
// Frontend repo is not using it, but some custom nodes are using the
// hidden field to pass various values.
hidden: z.record(z.any()).optional(),
}),
output: zComfyOutputSpec,
output_is_list: z.array(z.boolean()),
@@ -227,4 +246,15 @@ export type ComfyInputSpec = z.infer<typeof zInputSpec>;
export type ComfyOutputSpec = z.infer<typeof zComfyOutputSpec>;
export type ComfyNodeDef = z.infer<typeof zComfyNodeDef>;
// TODO: validate `/object_info` API endpoint responses.
export function validateComfyNodeDef(data: any): ComfyNodeDef {
const result = zComfyNodeDef.safeParse(data);
if (!result.success) {
const zodError = fromZodError(result.error);
const error = new Error(
`Invalid ComfyNodeDef: ${JSON.stringify(data)}\n${zodError.message}`
);
error.cause = zodError;
throw error;
}
return result.data;
}

View File

@@ -0,0 +1,77 @@
import { ComfyNodeDef, validateComfyNodeDef } from "@/types/apiTypes";
const fs = require("fs");
const path = require("path");
const EXAMPLE_NODE_DEF: ComfyNodeDef = {
input: {
required: {
ckpt_name: [["model1.safetensors", "model2.ckpt"]],
},
},
output: ["MODEL", "CLIP", "VAE"],
output_is_list: [false, false, false],
output_name: ["MODEL", "CLIP", "VAE"],
name: "CheckpointLoaderSimple",
display_name: "Load Checkpoint",
description: "",
python_module: "nodes",
category: "loaders",
output_node: false,
};
describe("validateNodeDef", () => {
it("Should accept a valid node definition", () => {
expect(() => validateComfyNodeDef(EXAMPLE_NODE_DEF)).not.toThrow();
});
describe.each([
[{ ckpt_name: "foo" }, ["foo", {}]],
[{ ckpt_name: ["foo"] }, ["foo", {}]],
[{ ckpt_name: ["foo", { default: 1 }] }, ["foo", { default: 1 }]],
])(
"validateComfyNodeDef with various input spec formats",
(inputSpec, expected) => {
it(`should accept input spec format: ${JSON.stringify(inputSpec)}`, () => {
expect(
validateComfyNodeDef({
...EXAMPLE_NODE_DEF,
input: {
required: inputSpec,
},
}).input.required.ckpt_name
).toEqual(expected);
});
}
);
describe.each([
[{ ckpt_name: { "model1.safetensors": "foo" } }],
[{ ckpt_name: ["*", ""] }],
[{ ckpt_name: ["foo", { default: 1 }, { default: 2 }] }],
])(
"validateComfyNodeDef rejects with various input spec formats",
(inputSpec) => {
it(`should accept input spec format: ${JSON.stringify(inputSpec)}`, () => {
expect(() =>
validateComfyNodeDef({
...EXAMPLE_NODE_DEF,
input: {
required: inputSpec,
},
})
).toThrow();
});
}
);
it("Should accept all built-in node definitions", async () => {
const nodeDefs = Object.values(
JSON.parse(
fs.readFileSync(path.resolve("./tests-ui/data/object_info.json"))
)
);
nodeDefs.forEach((nodeDef) => {
expect(() => validateComfyNodeDef(nodeDef)).not.toThrow();
});
});
});