diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index dfffae7c..ee4b6e72 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -379,7 +379,8 @@ class TrainConfig: self.do_prior_divergence = kwargs.get('do_prior_divergence', False) ema_config: Union[Dict, None] = kwargs.get('ema_config', None) - if ema_config is not None: + # if it is set explicitly to false, leave it false. + if ema_config is not None and ema_config.get('use_ema', None) is not None: ema_config['use_ema'] = True print(f"Using EMA") else: diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index 38c1305c..dde4dbac 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -23,6 +23,7 @@ export const defaultJobConfig: JobConfig = { training_folder: 'output', sqlite_db_path: './aitk_db.db', device: 'cuda:0', + trigger_word: null, network: { type: 'lora', linear: 16, @@ -32,6 +33,7 @@ export const defaultJobConfig: JobConfig = { dtype: 'bf16', save_every: 250, max_step_saves_to_keep: 4, + save_format: 'diffusers', push_to_hub: false, }, datasets: [ @@ -47,6 +49,8 @@ export const defaultJobConfig: JobConfig = { gradient_checkpointing: true, noise_scheduler: 'flowmatch', optimizer: 'adamw8bit', + timestep_type: 'sigmoid', + content_or_style: 'balanced', optimizer_params: { weight_decay: 1e-4 }, diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index 527334ed..e539ef4f 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -28,7 +28,6 @@ export default function TrainingForm() { const { datasets, status: datasetFetchStatus } = useDatasetList(); const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]); - const [jobConfig, setJobConfig] = useNestedState(objectCopy(defaultJobConfig)); const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); @@ -152,10 +151,21 @@ export default function TrainingForm() { setGpuIDs(value)} options={gpuList.map(gpu => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} /> + { + if (value?.trim() === '') { + value = null; + } + setJobConfig(value, 'jobConfig.config.process[0].trigger_word'); + }} + placeholder="" + required + /> {/* Model Configuration Section */} @@ -191,7 +201,8 @@ export default function TrainingForm() { label: model.name_or_path, }))} /> - + +
setJobConfig(value, 'config.process[0].model.quantize_te')} /> +
{jobConfig.config.process[0].network?.type && ( @@ -256,7 +268,6 @@ export default function TrainingForm() {
setJobConfig(value, 'config.process[0].train.batch_size')} placeholder="eg. 4" @@ -285,7 +296,6 @@ export default function TrainingForm() {
setJobConfig(value, 'config.process[0].train.optimizer')} options={[ @@ -312,6 +322,54 @@ export default function TrainingForm() { required />
+
+ setJobConfig(value, 'config.process[0].train.timestep_type')} + options={[ + { value: 'sigmoid', label: 'Sigmoid' }, + { value: 'linear', label: 'Linear' }, + { value: 'flux_shift', label: 'Flux Shift' }, + ]} + /> + setJobConfig(value, 'config.process[0].train.content_or_style')} + options={[ + { value: 'balanced', label: 'Balanced' }, + { value: 'content', label: 'High Noise' }, + { value: 'style', label: 'Low Noise' }, + ]} + /> + setJobConfig(value, 'config.process[0].train.noise_scheduler')} + options={[{ value: 'flowmatch', label: 'FlowMatch' }]} + /> +
+
+ + setJobConfig(value, 'config.process[0].train.ema_config.use_ema')} + /> + + setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')} + placeholder="eg. 0.99" + min={0} + /> +
@@ -341,36 +399,13 @@ export default function TrainingForm() { onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} options={datasetOptions} /> - {/* setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} - placeholder="eg. /path/to/images/folder" - required - /> */} - {/* { - let setValue: string | null = value; - if (!setValue || setValue.trim() === '') { - setValue = null; - } - setJobConfig(setValue, `config.process[0].datasets[${i}].mask_path`); - }} - placeholder="eg. /path/to/masks/folder" - /> - setJobConfig(value, `config.process[0].datasets[${i}].mask_min_value`)} - placeholder="eg. 0.1" - min={0} - max={1} - required - /> */} + setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)} + placeholder="eg. 1.0" + />
setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)} placeholder="eg. A photo of a cat" /> - setJobConfig(value, `config.process[0].datasets[${i}].caption_ext`)} - placeholder="eg. txt" - required - /> setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`) @@ -410,7 +437,6 @@ export default function TrainingForm() { /> setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} /> @@ -418,19 +444,28 @@ export default function TrainingForm() {
- {[256, 512, 768, 1024, 1280].map(res => ( - { - const resolutions = dataset.resolution.includes(res) - ? dataset.resolution.filter(r => r !== res) - : [...dataset.resolution, res]; - setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`); - }} - /> - ))} +
+ {[ + [256, 512, 768], + [1024, 1280, 1536], + ].map(resGroup => ( +
+ {resGroup.map(res => ( + { + const resolutions = dataset.resolution.includes(res) + ? dataset.resolution.filter(r => r !== res) + : [...dataset.resolution, res]; + setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`); + }} + /> + ))} +
+ ))} +
diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx index 64e1cf92..6b5bbcca 100644 --- a/ui/src/components/formInputs.tsx +++ b/ui/src/components/formInputs.tsx @@ -49,44 +49,55 @@ export interface NumberInputProps extends InputProps { export const NumberInput = (props: NumberInputProps) => { const { label, value, onChange, placeholder, required, min, max } = props; + + // Add controlled internal state to properly handle partial inputs + const [inputValue, setInputValue] = React.useState(value ?? ''); + + // Sync internal state with prop value + React.useEffect(() => { + setInputValue(value ?? ''); + }, [value]); + return (
{label && } { - // Use parseFloat instead of Number to properly handle decimal values const rawValue = e.target.value; + + // Update the input display with the raw value + setInputValue(rawValue); - // Special handling for empty or partial inputs - if (rawValue === '' || rawValue === '-' || rawValue === '.') { - // For empty or partial inputs (like just a minus sign or decimal point), - // we need to maintain the raw input in the input field - // but pass a valid number to onChange - onChange(0); + // Handle empty or partial inputs + if (rawValue === '' || rawValue === '-') { + // For empty or partial negative input, don't call onChange yet return; } - let value = Number(rawValue); + const numValue = Number(rawValue); - // Handle NaN cases - if (isNaN(value)) { - value = 0; + // Only apply constraints and call onChange when we have a valid number + if (!isNaN(numValue)) { + let constrainedValue = numValue; + + // Apply min/max constraints if they exist + if (min !== undefined && constrainedValue < min) { + constrainedValue = min; + } + if (max !== undefined && constrainedValue > max) { + constrainedValue = max; + } + + onChange(constrainedValue); } - - // Apply min/max constraints only for valid numbers - if (min !== undefined && value < min) value = min; - if (max !== undefined && value > max) value = max; - - onChange(value); }} className={inputClasses} placeholder={placeholder} required={required} min={min} max={max} - // Allow decimal points step="any" />
@@ -126,36 +137,43 @@ export interface CheckboxProps { export const Checkbox = (props: CheckboxProps) => { const { label, checked, onChange, required, disabled } = props; - const id = React.useId(); // Generate unique ID for label association + const id = React.useId(); return ( -
-
-
- onChange(e.target.checked)} - className="w-4 h-4 rounded border-gray-700 bg-gray-800 text-indigo-600 focus:ring-2 focus:ring-indigo-500 focus:ring-offset-1 focus:ring-offset-gray-900 cursor-pointer transition-colors" - required={required} - disabled={disabled} - /> -
- {label && ( -
- -
+
+
+ > + Toggle {label} + + + {label && ( + + )}
); }; diff --git a/ui/src/types.ts b/ui/src/types.ts index cc21aeb4..5d334128 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -59,6 +59,7 @@ export interface SaveConfig { dtype: string; save_every: number; max_step_saves_to_keep: number; + save_format: string; push_to_hub: boolean; } @@ -90,6 +91,8 @@ export interface TrainConfig { train_text_encoder: boolean; gradient_checkpointing: boolean; noise_scheduler: string; + timestep_type: string; + content_or_style: string; optimizer: string; lr: number; ema_config?: EMAConfig; @@ -129,6 +132,7 @@ export interface ProcessConfig { type: 'ui_trainer'; sqlite_db_path?: string; training_folder: string; + trigger_word: string | null; device: string; network?: NetworkConfig; save: SaveConfig;