mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added more settings to the training config
This commit is contained in:
@@ -379,7 +379,8 @@ class TrainConfig:
|
||||
self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
|
||||
|
||||
ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
|
||||
if ema_config is not None:
|
||||
# if it is set explicitly to false, leave it false.
|
||||
if ema_config is not None and ema_config.get('use_ema', None) is not None:
|
||||
ema_config['use_ema'] = True
|
||||
print(f"Using EMA")
|
||||
else:
|
||||
|
||||
@@ -23,6 +23,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
training_folder: 'output',
|
||||
sqlite_db_path: './aitk_db.db',
|
||||
device: 'cuda:0',
|
||||
trigger_word: null,
|
||||
network: {
|
||||
type: 'lora',
|
||||
linear: 16,
|
||||
@@ -32,6 +33,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
dtype: 'bf16',
|
||||
save_every: 250,
|
||||
max_step_saves_to_keep: 4,
|
||||
save_format: 'diffusers',
|
||||
push_to_hub: false,
|
||||
},
|
||||
datasets: [
|
||||
@@ -47,6 +49,8 @@ export const defaultJobConfig: JobConfig = {
|
||||
gradient_checkpointing: true,
|
||||
noise_scheduler: 'flowmatch',
|
||||
optimizer: 'adamw8bit',
|
||||
timestep_type: 'sigmoid',
|
||||
content_or_style: 'balanced',
|
||||
optimizer_params: {
|
||||
weight_decay: 1e-4
|
||||
},
|
||||
|
||||
@@ -28,7 +28,6 @@ export default function TrainingForm() {
|
||||
const { datasets, status: datasetFetchStatus } = useDatasetList();
|
||||
const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
|
||||
|
||||
|
||||
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
@@ -152,10 +151,21 @@ export default function TrainingForm() {
|
||||
<SelectInput
|
||||
label="GPU ID"
|
||||
value={`${gpuIDs}`}
|
||||
className="pt-2"
|
||||
onChange={value => setGpuIDs(value)}
|
||||
options={gpuList.map(gpu => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
|
||||
/>
|
||||
<TextInput
|
||||
label="Trigger Word"
|
||||
value={jobConfig.config.process[0].trigger_word || ''}
|
||||
onChange={(value: string | null) => {
|
||||
if (value?.trim() === '') {
|
||||
value = null;
|
||||
}
|
||||
setJobConfig(value, 'jobConfig.config.process[0].trigger_word');
|
||||
}}
|
||||
placeholder=""
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* Model Configuration Section */}
|
||||
@@ -191,7 +201,8 @@ export default function TrainingForm() {
|
||||
label: model.name_or_path,
|
||||
}))}
|
||||
/>
|
||||
<FormGroup label="Quantize" className="pt-2">
|
||||
<FormGroup label="Quantize">
|
||||
<div className='grid grid-cols-2 gap-2'>
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
@@ -202,6 +213,7 @@ export default function TrainingForm() {
|
||||
checked={jobConfig.config.process[0].model.quantize_te}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
{jobConfig.config.process[0].network?.type && (
|
||||
@@ -256,7 +268,6 @@ export default function TrainingForm() {
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Batch Size"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.batch_size}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
|
||||
placeholder="eg. 4"
|
||||
@@ -285,7 +296,6 @@ export default function TrainingForm() {
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Optimizer"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.optimizer}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
|
||||
options={[
|
||||
@@ -312,6 +322,54 @@ export default function TrainingForm() {
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Timestep Type"
|
||||
value={jobConfig.config.process[0].train.timestep_type}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
||||
options={[
|
||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||
{ value: 'linear', label: 'Linear' },
|
||||
{ value: 'flux_shift', label: 'Flux Shift' },
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Timestep Bias"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.content_or_style}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
|
||||
options={[
|
||||
{ value: 'balanced', label: 'Balanced' },
|
||||
{ value: 'content', label: 'High Noise' },
|
||||
{ value: 'style', label: 'Low Noise' },
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Noise Scheduler"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.noise_scheduler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
||||
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="EMA (Exponential Moving Average)">
|
||||
<Checkbox
|
||||
label="Use EMA"
|
||||
className='pt-1'
|
||||
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="EMA Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
|
||||
placeholder="eg. 0.99"
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
@@ -341,36 +399,13 @@ export default function TrainingForm() {
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
||||
options={datasetOptions}
|
||||
/>
|
||||
{/* <TextInput
|
||||
label="Folder Path"
|
||||
value={dataset.folder_path}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
||||
placeholder="eg. /path/to/images/folder"
|
||||
required
|
||||
/> */}
|
||||
{/* <TextInput
|
||||
label="Mask Folder Path"
|
||||
className="pt-2"
|
||||
value={dataset.mask_path || ''}
|
||||
onChange={value => {
|
||||
let setValue: string | null = value;
|
||||
if (!setValue || setValue.trim() === '') {
|
||||
setValue = null;
|
||||
}
|
||||
setJobConfig(setValue, `config.process[0].datasets[${i}].mask_path`);
|
||||
}}
|
||||
placeholder="eg. /path/to/masks/folder"
|
||||
/>
|
||||
<NumberInput
|
||||
label="Mask Min Value"
|
||||
className="pt-2"
|
||||
value={dataset.mask_min_value}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].mask_min_value`)}
|
||||
placeholder="eg. 0.1"
|
||||
min={0}
|
||||
max={1}
|
||||
required
|
||||
/> */}
|
||||
<NumberInput
|
||||
label="LoRA Weight"
|
||||
value={dataset.network_weight}
|
||||
className="pt-2"
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
|
||||
placeholder="eg. 1.0"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<TextInput
|
||||
@@ -379,14 +414,6 @@ export default function TrainingForm() {
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
|
||||
placeholder="eg. A photo of a cat"
|
||||
/>
|
||||
<TextInput
|
||||
label="Caption Extension"
|
||||
className="pt-2"
|
||||
value={dataset.caption_ext}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_ext`)}
|
||||
placeholder="eg. txt"
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Caption Dropout Rate"
|
||||
className="pt-2"
|
||||
@@ -402,7 +429,7 @@ export default function TrainingForm() {
|
||||
<div>
|
||||
<FormGroup label="Settings" className="">
|
||||
<Checkbox
|
||||
label="Cache Latents to Disk"
|
||||
label="Cache Latents"
|
||||
checked={dataset.cache_latents_to_disk || false}
|
||||
onChange={value =>
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
|
||||
@@ -410,7 +437,6 @@ export default function TrainingForm() {
|
||||
/>
|
||||
<Checkbox
|
||||
label="Is Regularization"
|
||||
className="pt-2"
|
||||
checked={dataset.is_reg || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||
/>
|
||||
@@ -418,19 +444,28 @@ export default function TrainingForm() {
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Resolutions" className="pt-2">
|
||||
{[256, 512, 768, 1024, 1280].map(res => (
|
||||
<Checkbox
|
||||
key={res}
|
||||
label={res.toString()}
|
||||
checked={dataset.resolution.includes(res)}
|
||||
onChange={value => {
|
||||
const resolutions = dataset.resolution.includes(res)
|
||||
? dataset.resolution.filter(r => r !== res)
|
||||
: [...dataset.resolution, res];
|
||||
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{[
|
||||
[256, 512, 768],
|
||||
[1024, 1280, 1536],
|
||||
].map(resGroup => (
|
||||
<div key={resGroup[0]} className="space-y-2">
|
||||
{resGroup.map(res => (
|
||||
<Checkbox
|
||||
key={res}
|
||||
label={res.toString()}
|
||||
checked={dataset.resolution.includes(res)}
|
||||
onChange={value => {
|
||||
const resolutions = dataset.resolution.includes(res)
|
||||
? dataset.resolution.filter(r => r !== res)
|
||||
: [...dataset.resolution, res];
|
||||
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</FormGroup>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -49,44 +49,55 @@ export interface NumberInputProps extends InputProps {
|
||||
|
||||
export const NumberInput = (props: NumberInputProps) => {
|
||||
const { label, value, onChange, placeholder, required, min, max } = props;
|
||||
|
||||
// Add controlled internal state to properly handle partial inputs
|
||||
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
||||
|
||||
// Sync internal state with prop value
|
||||
React.useEffect(() => {
|
||||
setInputValue(value ?? '');
|
||||
}, [value]);
|
||||
|
||||
return (
|
||||
<div className={classNames(props.className)}>
|
||||
{label && <label className={labelClasses}>{label}</label>}
|
||||
<input
|
||||
type="number"
|
||||
value={value}
|
||||
value={inputValue}
|
||||
onChange={e => {
|
||||
// Use parseFloat instead of Number to properly handle decimal values
|
||||
const rawValue = e.target.value;
|
||||
|
||||
// Update the input display with the raw value
|
||||
setInputValue(rawValue);
|
||||
|
||||
// Special handling for empty or partial inputs
|
||||
if (rawValue === '' || rawValue === '-' || rawValue === '.') {
|
||||
// For empty or partial inputs (like just a minus sign or decimal point),
|
||||
// we need to maintain the raw input in the input field
|
||||
// but pass a valid number to onChange
|
||||
onChange(0);
|
||||
// Handle empty or partial inputs
|
||||
if (rawValue === '' || rawValue === '-') {
|
||||
// For empty or partial negative input, don't call onChange yet
|
||||
return;
|
||||
}
|
||||
|
||||
let value = Number(rawValue);
|
||||
const numValue = Number(rawValue);
|
||||
|
||||
// Handle NaN cases
|
||||
if (isNaN(value)) {
|
||||
value = 0;
|
||||
// Only apply constraints and call onChange when we have a valid number
|
||||
if (!isNaN(numValue)) {
|
||||
let constrainedValue = numValue;
|
||||
|
||||
// Apply min/max constraints if they exist
|
||||
if (min !== undefined && constrainedValue < min) {
|
||||
constrainedValue = min;
|
||||
}
|
||||
if (max !== undefined && constrainedValue > max) {
|
||||
constrainedValue = max;
|
||||
}
|
||||
|
||||
onChange(constrainedValue);
|
||||
}
|
||||
|
||||
// Apply min/max constraints only for valid numbers
|
||||
if (min !== undefined && value < min) value = min;
|
||||
if (max !== undefined && value > max) value = max;
|
||||
|
||||
onChange(value);
|
||||
}}
|
||||
className={inputClasses}
|
||||
placeholder={placeholder}
|
||||
required={required}
|
||||
min={min}
|
||||
max={max}
|
||||
// Allow decimal points
|
||||
step="any"
|
||||
/>
|
||||
</div>
|
||||
@@ -126,36 +137,43 @@ export interface CheckboxProps {
|
||||
|
||||
export const Checkbox = (props: CheckboxProps) => {
|
||||
const { label, checked, onChange, required, disabled } = props;
|
||||
const id = React.useId(); // Generate unique ID for label association
|
||||
const id = React.useId();
|
||||
|
||||
return (
|
||||
<div className={classNames('flex items-center', props.className)}>
|
||||
<div className="relative flex items-start">
|
||||
<div className="flex items-center h-5">
|
||||
<input
|
||||
id={id}
|
||||
type="checkbox"
|
||||
checked={checked}
|
||||
onChange={e => onChange(e.target.checked)}
|
||||
className="w-4 h-4 rounded border-gray-700 bg-gray-800 text-indigo-600 focus:ring-2 focus:ring-indigo-500 focus:ring-offset-1 focus:ring-offset-gray-900 cursor-pointer transition-colors"
|
||||
required={required}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
{label && (
|
||||
<div className="ml-3 text-sm">
|
||||
<label
|
||||
htmlFor={id}
|
||||
className={classNames(
|
||||
'font-medium cursor-pointer select-none',
|
||||
disabled ? 'text-gray-500' : 'text-gray-300',
|
||||
)}
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
<div className={classNames('flex items-center gap-3', props.className)}>
|
||||
<button
|
||||
type="button"
|
||||
role="switch"
|
||||
id={id}
|
||||
aria-checked={checked}
|
||||
aria-required={required}
|
||||
disabled={disabled}
|
||||
onClick={() => !disabled && onChange(!checked)}
|
||||
className={classNames(
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
|
||||
checked ? 'bg-blue-600' : 'bg-gray-700',
|
||||
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
|
||||
)}
|
||||
</div>
|
||||
>
|
||||
<span className="sr-only">Toggle {label}</span>
|
||||
<span
|
||||
className={classNames(
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
checked ? 'translate-x-5' : 'translate-x-0'
|
||||
)}
|
||||
/>
|
||||
</button>
|
||||
{label && (
|
||||
<label
|
||||
htmlFor={id}
|
||||
className={classNames(
|
||||
'text-sm font-medium cursor-pointer select-none',
|
||||
disabled ? 'text-gray-500' : 'text-gray-300'
|
||||
)}
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -59,6 +59,7 @@ export interface SaveConfig {
|
||||
dtype: string;
|
||||
save_every: number;
|
||||
max_step_saves_to_keep: number;
|
||||
save_format: string;
|
||||
push_to_hub: boolean;
|
||||
}
|
||||
|
||||
@@ -90,6 +91,8 @@ export interface TrainConfig {
|
||||
train_text_encoder: boolean;
|
||||
gradient_checkpointing: boolean;
|
||||
noise_scheduler: string;
|
||||
timestep_type: string;
|
||||
content_or_style: string;
|
||||
optimizer: string;
|
||||
lr: number;
|
||||
ema_config?: EMAConfig;
|
||||
@@ -129,6 +132,7 @@ export interface ProcessConfig {
|
||||
type: 'ui_trainer';
|
||||
sqlite_db_path?: string;
|
||||
training_folder: string;
|
||||
trigger_word: string | null;
|
||||
device: string;
|
||||
network?: NetworkConfig;
|
||||
save: SaveConfig;
|
||||
|
||||
Reference in New Issue
Block a user