Allow input spec with extra values passthrough (#273)

* Allow input spec extra values passthrough

* Refine custom input spec

* nit

* nit
This commit is contained in:
Chenlei Hu
2024-07-31 09:51:32 -04:00
committed by GitHub
parent d77343da83
commit c00e2fd208
2 changed files with 30 additions and 24 deletions

View File

@@ -175,46 +175,49 @@ function inputSpec(
])
}
const zBaseInputSpecValue = z
.object({
default: z.any().optional(),
forceInput: z.boolean().optional()
})
.passthrough()
const zIntInputSpec = inputSpec([
z.literal('INT'),
z.object({
zBaseInputSpecValue.extend({
min: z.number().optional(),
max: z.number().optional(),
step: z.number().optional(),
default: z.number().optional(),
forceInput: z.boolean().optional()
default: z.number().optional()
})
])
const zFloatInputSpec = inputSpec([
z.literal('FLOAT'),
z.object({
zBaseInputSpecValue.extend({
min: z.number().optional(),
max: z.number().optional(),
step: z.number().optional(),
round: z.number().optional(),
default: z.number().optional(),
forceInput: z.boolean().optional()
round: z.union([z.number(), z.literal(false)]).optional(),
default: z.number().optional()
})
])
const zBooleanInputSpec = inputSpec([
z.literal('BOOLEAN'),
z.object({
zBaseInputSpecValue.extend({
label_on: z.string().optional(),
label_off: z.string().optional(),
default: z.boolean().optional(),
forceInput: z.boolean().optional()
default: z.boolean().optional()
})
])
const zStringInputSpec = inputSpec([
z.literal('STRING'),
z.object({
zBaseInputSpecValue.extend({
default: z.string().optional(),
multiline: z.boolean().optional(),
dynamicPrompts: z.boolean().optional(),
forceInput: z.boolean().optional()
dynamicPrompts: z.boolean().optional()
})
])
@@ -222,22 +225,19 @@ const zStringInputSpec = inputSpec([
const zComboInputSpec = inputSpec(
[
z.array(z.any()),
z.object({
default: z.any().optional(),
zBaseInputSpecValue.extend({
control_after_generate: z.boolean().optional(),
image_upload: z.boolean().optional(),
forceInput: z.boolean().optional()
image_upload: z.boolean().optional()
})
],
/* allowUpcast=*/ false
)
const excludedLiterals = new Set(['INT', 'FLOAT', 'BOOLEAN', 'STRING', 'COMBO'])
const zCustomInputSpec = inputSpec([
z.string(),
z.object({
default: z.any().optional(),
forceInput: z.boolean().optional()
})
z.string().refine((value) => !excludedLiterals.has(value)),
zBaseInputSpecValue
])
const zInputSpec = z.union([

View File

@@ -27,7 +27,11 @@ describe('validateNodeDef', () => {
describe.each([
[{ ckpt_name: 'foo' }, ['foo', {}]],
[{ ckpt_name: ['foo'] }, ['foo', {}]],
[{ ckpt_name: ['foo', { default: 1 }] }, ['foo', { default: 1 }]]
[{ ckpt_name: ['foo', { default: 1 }] }, ['foo', { default: 1 }]],
// Extra input spec should be preserved
[{ ckpt_name: ['foo', { bar: 1 }] }, ['foo', { bar: 1 }]],
[{ ckpt_name: ['INT', { bar: 1 }] }, ['INT', { bar: 1 }]],
[{ ckpt_name: [[1, 2, 3], { bar: 1 }] }, [[1, 2, 3], { bar: 1 }]]
])(
'validateComfyNodeDef with various input spec formats',
(inputSpec, expected) => {
@@ -47,7 +51,9 @@ describe('validateNodeDef', () => {
describe.each([
[{ ckpt_name: { 'model1.safetensors': 'foo' } }],
[{ ckpt_name: ['*', ''] }],
[{ ckpt_name: ['foo', { default: 1 }, { default: 2 }] }]
[{ ckpt_name: ['foo', { default: 1 }, { default: 2 }] }],
// Should reject incorrect default value type.
[{ ckpt_name: ['INT', { default: '124' }] }]
])(
'validateComfyNodeDef rejects with various input spec formats',
(inputSpec) => {