Add stepped loss type

This commit is contained in:
Jaret Burkett
2025-09-22 15:50:12 -06:00
parent 28728a1e92
commit f74475161e
7 changed files with 108 additions and 46 deletions

View File

@@ -524,13 +524,15 @@ export default function SimpleJob({
]}
/>
<SelectInput
label="Noise Scheduler"
label="Loss Type"
className="pt-2"
value={jobConfig.config.process[0].train.noise_scheduler}
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
value={jobConfig.config.process[0].train.loss_type}
onChange={value => setJobConfig(value, 'config.process[0].train.loss_type')}
options={[
{ value: 'flowmatch', label: 'FlowMatch' },
{ value: 'ddpm', label: 'DDPM' },
{ value: 'mse', label: 'Mean Squared Error' },
{ value: 'mae', label: 'Mean Absolute Error' },
{ value: 'wavelet', label: 'Wavelet' },
{ value: 'stepped', label: 'Stepped Recovery' },
]}
/>
</div>

View File

@@ -91,6 +91,7 @@ export const defaultJobConfig: JobConfig = {
diff_output_preservation_multiplier: 1.0,
diff_output_preservation_class: 'person',
switch_boundary_every: 1,
loss_type: 'mse',
},
model: {
name_or_path: 'ostris/Flex.1-alpha',

View File

@@ -123,6 +123,7 @@ export interface TrainConfig {
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;
switch_boundary_every: number;
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
}
export interface QuantizeKwargsConfig {