From c00e2fd2084ba0006a2a7486f1300f877b2f19fa Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 31 Jul 2024 09:51:32 -0400 Subject: [PATCH] Allow input spec with extra values passthrough (#273) * Allow input spec extra values passthrough * Refine custom input spec * nit * nit --- src/types/apiTypes.ts | 44 ++++++++++++++++----------------- tests-ui/tests/apiTypes.test.ts | 10 ++++++-- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/types/apiTypes.ts b/src/types/apiTypes.ts index 00c4638b2..bcb0a0073 100644 --- a/src/types/apiTypes.ts +++ b/src/types/apiTypes.ts @@ -175,46 +175,49 @@ function inputSpec( ]) } +const zBaseInputSpecValue = z + .object({ + default: z.any().optional(), + forceInput: z.boolean().optional() + }) + .passthrough() + const zIntInputSpec = inputSpec([ z.literal('INT'), - z.object({ + zBaseInputSpecValue.extend({ min: z.number().optional(), max: z.number().optional(), step: z.number().optional(), - default: z.number().optional(), - forceInput: z.boolean().optional() + default: z.number().optional() }) ]) const zFloatInputSpec = inputSpec([ z.literal('FLOAT'), - z.object({ + zBaseInputSpecValue.extend({ min: z.number().optional(), max: z.number().optional(), step: z.number().optional(), - round: z.number().optional(), - default: z.number().optional(), - forceInput: z.boolean().optional() + round: z.union([z.number(), z.literal(false)]).optional(), + default: z.number().optional() }) ]) const zBooleanInputSpec = inputSpec([ z.literal('BOOLEAN'), - z.object({ + zBaseInputSpecValue.extend({ label_on: z.string().optional(), label_off: z.string().optional(), - default: z.boolean().optional(), - forceInput: z.boolean().optional() + default: z.boolean().optional() }) ]) const zStringInputSpec = inputSpec([ z.literal('STRING'), - z.object({ + zBaseInputSpecValue.extend({ default: z.string().optional(), multiline: z.boolean().optional(), - dynamicPrompts: z.boolean().optional(), - forceInput: z.boolean().optional() + dynamicPrompts: z.boolean().optional() }) ]) @@ -222,22 +225,19 @@ const zStringInputSpec = inputSpec([ const zComboInputSpec = inputSpec( [ z.array(z.any()), - z.object({ - default: z.any().optional(), + zBaseInputSpecValue.extend({ control_after_generate: z.boolean().optional(), - image_upload: z.boolean().optional(), - forceInput: z.boolean().optional() + image_upload: z.boolean().optional() }) ], /* allowUpcast=*/ false ) +const excludedLiterals = new Set(['INT', 'FLOAT', 'BOOLEAN', 'STRING', 'COMBO']) + const zCustomInputSpec = inputSpec([ - z.string(), - z.object({ - default: z.any().optional(), - forceInput: z.boolean().optional() - }) + z.string().refine((value) => !excludedLiterals.has(value)), + zBaseInputSpecValue ]) const zInputSpec = z.union([ diff --git a/tests-ui/tests/apiTypes.test.ts b/tests-ui/tests/apiTypes.test.ts index e3c76d0cf..032a3b853 100644 --- a/tests-ui/tests/apiTypes.test.ts +++ b/tests-ui/tests/apiTypes.test.ts @@ -27,7 +27,11 @@ describe('validateNodeDef', () => { describe.each([ [{ ckpt_name: 'foo' }, ['foo', {}]], [{ ckpt_name: ['foo'] }, ['foo', {}]], - [{ ckpt_name: ['foo', { default: 1 }] }, ['foo', { default: 1 }]] + [{ ckpt_name: ['foo', { default: 1 }] }, ['foo', { default: 1 }]], + // Extra input spec should be preserved + [{ ckpt_name: ['foo', { bar: 1 }] }, ['foo', { bar: 1 }]], + [{ ckpt_name: ['INT', { bar: 1 }] }, ['INT', { bar: 1 }]], + [{ ckpt_name: [[1, 2, 3], { bar: 1 }] }, [[1, 2, 3], { bar: 1 }]] ])( 'validateComfyNodeDef with various input spec formats', (inputSpec, expected) => { @@ -47,7 +51,9 @@ describe('validateNodeDef', () => { describe.each([ [{ ckpt_name: { 'model1.safetensors': 'foo' } }], [{ ckpt_name: ['*', ''] }], - [{ ckpt_name: ['foo', { default: 1 }, { default: 2 }] }] + [{ ckpt_name: ['foo', { default: 1 }, { default: 2 }] }], + // Should reject incorrect default value type. + [{ ckpt_name: ['INT', { default: '124' }] }] ])( 'validateComfyNodeDef rejects with various input spec formats', (inputSpec) => {