Allow special args for schedulers

This commit is contained in:
Jaret Burkett
2023-09-03 20:38:44 -06:00
parent 7cd6945082
commit 22ed539321
3 changed files with 25 additions and 11 deletions

View File

@@ -71,7 +71,9 @@ class TrainConfig:
self.unet_lr = kwargs.get('unet_lr', self.lr)
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
self.optimizer = kwargs.get('optimizer', 'adamw')
self.optimizer_params = kwargs.get('optimizer_params', {})
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
self.batch_size: int = kwargs.get('batch_size', 1)
self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -80,7 +82,6 @@ class TrainConfig:
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.optimizer_params = kwargs.get('optimizer_params', {})
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)