Added more settings to the training config

This commit is contained in:
Jaret Burkett
2025-02-23 12:34:52 -07:00
parent a280f78c69
commit b366e46f1c
5 changed files with 167 additions and 105 deletions

View File

@@ -379,7 +379,8 @@ class TrainConfig:
self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
if ema_config is not None:
# if it is set explicitly to false, leave it false.
if ema_config is not None and ema_config.get('use_ema', None) is not None:
ema_config['use_ema'] = True
print(f"Using EMA")
else:

View File

@@ -23,6 +23,7 @@ export const defaultJobConfig: JobConfig = {
training_folder: 'output',
sqlite_db_path: './aitk_db.db',
device: 'cuda:0',
trigger_word: null,
network: {
type: 'lora',
linear: 16,
@@ -32,6 +33,7 @@ export const defaultJobConfig: JobConfig = {
dtype: 'bf16',
save_every: 250,
max_step_saves_to_keep: 4,
save_format: 'diffusers',
push_to_hub: false,
},
datasets: [
@@ -47,6 +49,8 @@ export const defaultJobConfig: JobConfig = {
gradient_checkpointing: true,
noise_scheduler: 'flowmatch',
optimizer: 'adamw8bit',
timestep_type: 'sigmoid',
content_or_style: 'balanced',
optimizer_params: {
weight_decay: 1e-4
},

View File

@@ -28,7 +28,6 @@ export default function TrainingForm() {
const { datasets, status: datasetFetchStatus } = useDatasetList();
const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
@@ -152,10 +151,21 @@ export default function TrainingForm() {
<SelectInput
label="GPU ID"
value={`${gpuIDs}`}
className="pt-2"
onChange={value => setGpuIDs(value)}
options={gpuList.map(gpu => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
{/* Trigger word field: a blank or whitespace-only entry is stored as null. */}
<TextInput
label="Trigger Word"
value={jobConfig.config.process[0].trigger_word || ''}
onChange={(value: string | null) => {
const triggerWord = value?.trim() === '' ? null : value;
// The path is relative to the job config state, matching the other setters in this form
// and the location the value above is read from.
setJobConfig(triggerWord, 'config.process[0].trigger_word');
}}
placeholder=""
required
/>
</Card>
{/* Model Configuration Section */}
@@ -191,7 +201,8 @@ export default function TrainingForm() {
label: model.name_or_path,
}))}
/>
<FormGroup label="Quantize" className="pt-2">
<FormGroup label="Quantize">
<div className='grid grid-cols-2 gap-2'>
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
@@ -202,6 +213,7 @@ export default function TrainingForm() {
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</div>
</FormGroup>
</Card>
{jobConfig.config.process[0].network?.type && (
@@ -256,7 +268,6 @@ export default function TrainingForm() {
<div>
<NumberInput
label="Batch Size"
className="pt-2"
value={jobConfig.config.process[0].train.batch_size}
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
placeholder="eg. 4"
@@ -285,7 +296,6 @@ export default function TrainingForm() {
<div>
<SelectInput
label="Optimizer"
className="pt-2"
value={jobConfig.config.process[0].train.optimizer}
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
options={[
@@ -312,6 +322,54 @@ export default function TrainingForm() {
required
/>
</div>
<div>
<SelectInput
label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'flux_shift', label: 'Flux Shift' },
]}
/>
<SelectInput
label="Timestep Bias"
className="pt-2"
value={jobConfig.config.process[0].train.content_or_style}
onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
options={[
{ value: 'balanced', label: 'Balanced' },
{ value: 'content', label: 'High Noise' },
{ value: 'style', label: 'Low Noise' },
]}
/>
<SelectInput
label="Noise Scheduler"
className="pt-2"
value={jobConfig.config.process[0].train.noise_scheduler}
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
/>
</div>
<div>
<FormGroup label="EMA (Exponential Moving Average)">
<Checkbox
label="Use EMA"
className='pt-1'
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
/>
</FormGroup>
{/* EMA decay: stored on the same ema_config object that the "Use EMA" toggle writes to. */}
<NumberInput
label="EMA Decay"
className="pt-2"
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
// The setter path is a plain string, so it uses ordinary dots rather than optional chaining.
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.ema_decay')}
placeholder="eg. 0.99"
min={0}
/>
</div>
</div>
</Card>
</div>
@@ -341,36 +399,13 @@ export default function TrainingForm() {
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
options={datasetOptions}
/>
{/* <TextInput
label="Folder Path"
value={dataset.folder_path}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
placeholder="eg. /path/to/images/folder"
required
/> */}
{/* <TextInput
label="Mask Folder Path"
className="pt-2"
value={dataset.mask_path || ''}
onChange={value => {
let setValue: string | null = value;
if (!setValue || setValue.trim() === '') {
setValue = null;
}
setJobConfig(setValue, `config.process[0].datasets[${i}].mask_path`);
}}
placeholder="eg. /path/to/masks/folder"
/>
<NumberInput
label="Mask Min Value"
className="pt-2"
value={dataset.mask_min_value}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].mask_min_value`)}
placeholder="eg. 0.1"
min={0}
max={1}
required
/> */}
<NumberInput
label="LoRA Weight"
value={dataset.network_weight}
className="pt-2"
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
placeholder="eg. 1.0"
/>
</div>
<div>
<TextInput
@@ -379,14 +414,6 @@ export default function TrainingForm() {
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
placeholder="eg. A photo of a cat"
/>
<TextInput
label="Caption Extension"
className="pt-2"
value={dataset.caption_ext}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_ext`)}
placeholder="eg. txt"
required
/>
<NumberInput
label="Caption Dropout Rate"
className="pt-2"
@@ -402,7 +429,7 @@ export default function TrainingForm() {
<div>
<FormGroup label="Settings" className="">
<Checkbox
label="Cache Latents to Disk"
label="Cache Latents"
checked={dataset.cache_latents_to_disk || false}
onChange={value =>
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
@@ -410,7 +437,6 @@ export default function TrainingForm() {
/>
<Checkbox
label="Is Regularization"
className="pt-2"
checked={dataset.is_reg || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
/>
@@ -418,19 +444,28 @@ export default function TrainingForm() {
</div>
<div>
<FormGroup label="Resolutions" className="pt-2">
{[256, 512, 768, 1024, 1280].map(res => (
<Checkbox
key={res}
label={res.toString()}
checked={dataset.resolution.includes(res)}
onChange={value => {
const resolutions = dataset.resolution.includes(res)
? dataset.resolution.filter(r => r !== res)
: [...dataset.resolution, res];
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
}}
/>
))}
<div className="grid grid-cols-2 gap-2">
{[
[256, 512, 768],
[1024, 1280, 1536],
].map(resGroup => (
<div key={resGroup[0]} className="space-y-2">
{resGroup.map(res => (
<Checkbox
key={res}
label={res.toString()}
checked={dataset.resolution.includes(res)}
onChange={value => {
const resolutions = dataset.resolution.includes(res)
? dataset.resolution.filter(r => r !== res)
: [...dataset.resolution, res];
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
}}
/>
))}
</div>
))}
</div>
</FormGroup>
</div>
</div>

View File

@@ -49,44 +49,55 @@ export interface NumberInputProps extends InputProps {
export const NumberInput = (props: NumberInputProps) => {
const { label, value, onChange, placeholder, required, min, max } = props;
// Add controlled internal state to properly handle partial inputs
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
// Sync internal state with prop value
React.useEffect(() => {
setInputValue(value ?? '');
}, [value]);
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
<input
type="number"
value={value}
value={inputValue}
onChange={e => {
// Use parseFloat instead of Number to properly handle decimal values
const rawValue = e.target.value;
// Update the input display with the raw value
setInputValue(rawValue);
// Special handling for empty or partial inputs
if (rawValue === '' || rawValue === '-' || rawValue === '.') {
// For empty or partial inputs (like just a minus sign or decimal point),
// we need to maintain the raw input in the input field
// but pass a valid number to onChange
onChange(0);
// Handle empty or partial inputs
if (rawValue === '' || rawValue === '-') {
// For empty or partial negative input, don't call onChange yet
return;
}
let value = Number(rawValue);
const numValue = Number(rawValue);
// Handle NaN cases
if (isNaN(value)) {
value = 0;
// Only apply constraints and call onChange when we have a valid number
if (!isNaN(numValue)) {
let constrainedValue = numValue;
// Apply min/max constraints if they exist
if (min !== undefined && constrainedValue < min) {
constrainedValue = min;
}
if (max !== undefined && constrainedValue > max) {
constrainedValue = max;
}
onChange(constrainedValue);
}
// Apply min/max constraints only for valid numbers
if (min !== undefined && value < min) value = min;
if (max !== undefined && value > max) value = max;
onChange(value);
}}
className={inputClasses}
placeholder={placeholder}
required={required}
min={min}
max={max}
// Allow decimal points
step="any"
/>
</div>
@@ -126,36 +137,43 @@ export interface CheckboxProps {
export const Checkbox = (props: CheckboxProps) => {
const { label, checked, onChange, required, disabled } = props;
const id = React.useId(); // Generate unique ID for label association
const id = React.useId();
return (
<div className={classNames('flex items-center', props.className)}>
<div className="relative flex items-start">
<div className="flex items-center h-5">
<input
id={id}
type="checkbox"
checked={checked}
onChange={e => onChange(e.target.checked)}
className="w-4 h-4 rounded border-gray-700 bg-gray-800 text-indigo-600 focus:ring-2 focus:ring-indigo-500 focus:ring-offset-1 focus:ring-offset-gray-900 cursor-pointer transition-colors"
required={required}
disabled={disabled}
/>
</div>
{label && (
<div className="ml-3 text-sm">
<label
htmlFor={id}
className={classNames(
'font-medium cursor-pointer select-none',
disabled ? 'text-gray-500' : 'text-gray-300',
)}
>
{label}
</label>
</div>
<div className={classNames('flex items-center gap-3', props.className)}>
<button
type="button"
role="switch"
id={id}
aria-checked={checked}
aria-required={required}
disabled={disabled}
onClick={() => !disabled && onChange(!checked)}
className={classNames(
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
checked ? 'bg-blue-600' : 'bg-gray-700',
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
)}
</div>
>
<span className="sr-only">Toggle {label}</span>
<span
className={classNames(
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
checked ? 'translate-x-5' : 'translate-x-0'
)}
/>
</button>
{label && (
<label
htmlFor={id}
className={classNames(
'text-sm font-medium cursor-pointer select-none',
disabled ? 'text-gray-500' : 'text-gray-300'
)}
>
{label}
</label>
)}
</div>
);
};

View File

@@ -59,6 +59,7 @@ export interface SaveConfig {
dtype: string;
save_every: number;
max_step_saves_to_keep: number;
save_format: string;
push_to_hub: boolean;
}
@@ -90,6 +91,8 @@ export interface TrainConfig {
train_text_encoder: boolean;
gradient_checkpointing: boolean;
noise_scheduler: string;
timestep_type: string;
content_or_style: string;
optimizer: string;
lr: number;
ema_config?: EMAConfig;
@@ -129,6 +132,7 @@ export interface ProcessConfig {
type: 'ui_trainer';
sqlite_db_path?: string;
training_folder: string;
trigger_word: string | null;
device: string;
network?: NetworkConfig;
save: SaveConfig;