mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Allow special args for schedulers
This commit is contained in:
@@ -71,7 +71,9 @@ class TrainConfig:
|
||||
self.unet_lr = kwargs.get('unet_lr', self.lr)
|
||||
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
|
||||
self.optimizer = kwargs.get('optimizer', 'adamw')
|
||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
|
||||
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
|
||||
self.batch_size: int = kwargs.get('batch_size', 1)
|
||||
self.dtype: str = kwargs.get('dtype', 'fp32')
|
||||
@@ -80,7 +82,6 @@ class TrainConfig:
|
||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
|
||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
||||
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
|
||||
|
||||
Reference in New Issue
Block a user