mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-14 23:17:22 +00:00
Added Differential Output Preservation Loss to trainer and ui
This commit is contained in:
@@ -61,6 +61,10 @@ export const defaultJobConfig: JobConfig = {
|
||||
ema_decay: 0.99,
|
||||
},
|
||||
dtype: 'bf16',
|
||||
diff_output_preservation: false,
|
||||
diff_output_preservation_multiplier: 1.0,
|
||||
diff_output_preservation_class: 'person'
|
||||
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
|
||||
@@ -275,7 +275,7 @@ export default function TrainingForm() {
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Training Configuration">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Batch Size"
|
||||
@@ -384,6 +384,31 @@ export default function TrainingForm() {
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Regularization">
|
||||
<Checkbox
|
||||
label="Differtial Output Preservation"
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="DFE Loss Multiplier"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
/>
|
||||
<TextInput
|
||||
label="DFE Preservation Class"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
|
||||
placeholder="eg. woman"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
@@ -100,6 +100,9 @@ export interface TrainConfig {
|
||||
optimizer_params: {
|
||||
weight_decay: number;
|
||||
};
|
||||
diff_output_preservation: boolean;
|
||||
diff_output_preservation_multiplier: number;
|
||||
diff_output_preservation_class: string;
|
||||
}
|
||||
|
||||
export interface QuantizeKwargsConfig {
|
||||
|
||||
Reference in New Issue
Block a user