mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Allow control image for omnigen training and sampling
This commit is contained in:
@@ -310,6 +310,7 @@ export default function SimpleJob({
|
||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||
{ value: 'linear', label: 'Linear' },
|
||||
{ value: 'shift', label: 'Shift' },
|
||||
{ value: 'weighted', label: 'Weighted' },
|
||||
]}
|
||||
/>
|
||||
)}
|
||||
@@ -541,13 +542,12 @@ export default function SimpleJob({
|
||||
{ value: 'ddpm', label: 'DDPM' },
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Guidance Scale"
|
||||
value={jobConfig.config.process[0].sample.guidance_scale}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
|
||||
placeholder="eg. 1.0"
|
||||
className="pt-2"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
@@ -579,6 +579,26 @@ export default function SimpleJob({
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
{isVideoModel && (
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Num Frames"
|
||||
value={jobConfig.config.process[0].sample.num_frames}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="FPS"
|
||||
value={jobConfig.config.process[0].sample.fps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div>
|
||||
@@ -597,40 +617,36 @@ export default function SimpleJob({
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
|
||||
/>
|
||||
</div>
|
||||
{isVideoModel && (
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Num Frames"
|
||||
value={jobConfig.config.process[0].sample.num_frames}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="FPS"
|
||||
value={jobConfig.config.process[0].sample.fps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<FormGroup label="Advanced Sampling" className="pt-2">
|
||||
<div>
|
||||
<Checkbox
|
||||
label="Skip First Sample"
|
||||
className="pt-4"
|
||||
checked={jobConfig.config.process[0].train.skip_first_sample || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.skip_first_sample')}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<Checkbox
|
||||
label="Disable Sampling"
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.disable_sampling || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.disable_sampling')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
</div>
|
||||
</div>
|
||||
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
|
||||
{
|
||||
modelArch?.additionalSections?.includes('sample.ctrl_img') && (
|
||||
<div className='text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg'>
|
||||
<p className='font-semibold mb-1'>
|
||||
Control Images
|
||||
</p>
|
||||
To use control images on samples, add --ctrl_img to the prompts below.
|
||||
<br />
|
||||
Example: <code className='bg-yellow-900 p-1'>make this a cartoon --ctrl_img /path/to/image.png</code>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
|
||||
<div className="text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg">
|
||||
<p className="font-semibold mb-1">Control Images</p>
|
||||
To use control images on samples, add --ctrl_img to the prompts below.
|
||||
<br />
|
||||
Example: <code className="bg-yellow-900 p-1">make this a cartoon --ctrl_img /path/to/image.png</code>
|
||||
</div>
|
||||
)}
|
||||
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
|
||||
<div key={i} className="flex items-center space-x-2">
|
||||
<div className="flex-1">
|
||||
|
||||
@@ -68,6 +68,8 @@ export const defaultJobConfig: JobConfig = {
|
||||
use_ema: false,
|
||||
ema_decay: 0.99,
|
||||
},
|
||||
skip_first_sample: false,
|
||||
disable_sampling: false,
|
||||
dtype: 'bf16',
|
||||
diff_output_preservation: false,
|
||||
diff_output_preservation_multiplier: 1.0,
|
||||
|
||||
@@ -200,6 +200,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
|
||||
},
|
||||
].sort((a, b) => {
|
||||
// Sort by label, case-insensitive
|
||||
|
||||
@@ -110,6 +110,8 @@ export interface TrainConfig {
|
||||
optimizer_params: {
|
||||
weight_decay: number;
|
||||
};
|
||||
skip_first_sample: boolean;
|
||||
disable_sampling: boolean;
|
||||
diff_output_preservation: boolean;
|
||||
diff_output_preservation_multiplier: number;
|
||||
diff_output_preservation_class: string;
|
||||
|
||||
Reference in New Issue
Block a user