Allow control image for omnigen training and sampling

This commit is contained in:
Jaret Burkett
2025-07-09 13:54:55 -06:00
parent bbb57de6ec
commit 611969ec1f
6 changed files with 187 additions and 132 deletions

View File

@@ -310,6 +310,7 @@ export default function SimpleJob({
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'shift', label: 'Shift' },
{ value: 'weighted', label: 'Weighted' },
]}
/>
)}
@@ -541,13 +542,12 @@ export default function SimpleJob({
{ value: 'ddpm', label: 'DDPM' },
]}
/>
</div>
<div>
<NumberInput
label="Guidance Scale"
value={jobConfig.config.process[0].sample.guidance_scale}
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
placeholder="eg. 1.0"
className="pt-2"
min={0}
required
/>
@@ -579,6 +579,26 @@ export default function SimpleJob({
min={0}
required
/>
{isVideoModel && (
<div>
<NumberInput
label="Num Frames"
value={jobConfig.config.process[0].sample.num_frames}
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
placeholder="eg. 0"
min={0}
required
/>
<NumberInput
label="FPS"
value={jobConfig.config.process[0].sample.fps}
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
placeholder="eg. 0"
min={0}
required
/>
</div>
)}
</div>
<div>
@@ -597,40 +617,36 @@ export default function SimpleJob({
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
/>
</div>
{isVideoModel && (
<div>
<NumberInput
label="Num Frames"
value={jobConfig.config.process[0].sample.num_frames}
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
placeholder="eg. 0"
min={0}
required
/>
<NumberInput
label="FPS"
value={jobConfig.config.process[0].sample.fps}
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
placeholder="eg. 0"
min={0}
required
/>
</div>
)}
<div>
<FormGroup label="Advanced Sampling" className="pt-2">
<div>
<Checkbox
label="Skip First Sample"
className="pt-4"
checked={jobConfig.config.process[0].train.skip_first_sample || false}
onChange={value => setJobConfig(value, 'config.process[0].train.skip_first_sample')}
/>
</div>
<div>
<Checkbox
label="Disable Sampling"
className="pt-1"
checked={jobConfig.config.process[0].train.disable_sampling || false}
onChange={value => setJobConfig(value, 'config.process[0].train.disable_sampling')}
/>
</div>
</FormGroup>
</div>
</div>
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
{
modelArch?.additionalSections?.includes('sample.ctrl_img') && (
<div className='text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg'>
<p className='font-semibold mb-1'>
Control Images
</p>
To use control images on samples, add --ctrl_img to the prompts below.
<br />
Example: <code className='bg-yellow-900 p-1'>make this a cartoon --ctrl_img /path/to/image.png</code>
</div>
)
}
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
<div className="text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg">
<p className="font-semibold mb-1">Control Images</p>
To use control images on samples, add --ctrl_img to the prompts below.
<br />
Example: <code className="bg-yellow-900 p-1">make this a cartoon --ctrl_img /path/to/image.png</code>
</div>
)}
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
<div key={i} className="flex items-center space-x-2">
<div className="flex-1">

View File

@@ -68,6 +68,8 @@ export const defaultJobConfig: JobConfig = {
use_ema: false,
ema_decay: 0.99,
},
skip_first_sample: false,
disable_sampling: false,
dtype: 'bf16',
diff_output_preservation: false,
diff_output_preservation_multiplier: 1.0,

View File

@@ -200,6 +200,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].model.quantize_te': [true, false],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
},
].sort((a, b) => {
// Sort by label, case-insensitive

View File

@@ -110,6 +110,8 @@ export interface TrainConfig {
optimizer_params: {
weight_decay: number;
};
skip_first_sample: boolean;
disable_sampling: boolean;
diff_output_preservation: boolean;
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;