Added Differential Output Preservation Loss to trainer and ui

This commit is contained in:
Jaret Burkett
2025-02-25 20:12:36 -07:00
parent 259ded9602
commit f6e16e582a
6 changed files with 127 additions and 6 deletions

View File

@@ -61,6 +61,10 @@ export const defaultJobConfig: JobConfig = {
ema_decay: 0.99,
},
dtype: 'bf16',
diff_output_preservation: false,
diff_output_preservation_multiplier: 1.0,
diff_output_preservation_class: 'person'
},
model: {
name_or_path: 'ostris/Flex.1-alpha',

View File

@@ -275,7 +275,7 @@ export default function TrainingForm() {
</div>
<div>
<Card title="Training Configuration">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
<div>
<NumberInput
label="Batch Size"
@@ -384,6 +384,31 @@ export default function TrainingForm() {
min={0}
/>
</div>
<div>
<FormGroup label="Regularization">
<Checkbox
label="Differtial Output Preservation"
className="pt-1"
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
/>
</FormGroup>
<NumberInput
label="DFE Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
placeholder="eg. 1.0"
min={0}
/>
<TextInput
label="DFE Preservation Class"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
placeholder="eg. woman"
/>
</div>
</div>
</Card>
</div>

View File

@@ -100,6 +100,9 @@ export interface TrainConfig {
optimizer_params: {
weight_decay: number;
};
diff_output_preservation: boolean;
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;
}
export interface QuantizeKwargsConfig {