mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add stepped loss type
This commit is contained in:
@@ -524,13 +524,15 @@ export default function SimpleJob({
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Noise Scheduler"
|
||||
label="Loss Type"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.noise_scheduler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
||||
value={jobConfig.config.process[0].train.loss_type}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.loss_type')}
|
||||
options={[
|
||||
{ value: 'flowmatch', label: 'FlowMatch' },
|
||||
{ value: 'ddpm', label: 'DDPM' },
|
||||
{ value: 'mse', label: 'Mean Squared Error' },
|
||||
{ value: 'mae', label: 'Mean Absolute Error' },
|
||||
{ value: 'wavelet', label: 'Wavelet' },
|
||||
{ value: 'stepped', label: 'Stepped Recovery' },
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -91,6 +91,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
diff_output_preservation_multiplier: 1.0,
|
||||
diff_output_preservation_class: 'person',
|
||||
switch_boundary_every: 1,
|
||||
loss_type: 'mse',
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
|
||||
@@ -123,6 +123,7 @@ export interface TrainConfig {
|
||||
diff_output_preservation_multiplier: number;
|
||||
diff_output_preservation_class: string;
|
||||
switch_boundary_every: number;
|
||||
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
||||
}
|
||||
|
||||
export interface QuantizeKwargsConfig {
|
||||
|
||||
Reference in New Issue
Block a user