From 22ed539321c5df5afd8d24015e1db5dd107a6908 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 3 Sep 2023 20:38:44 -0600
Subject: [PATCH] Allow special args for schedulers

---
 jobs/process/BaseSDTrainProcess.py | 12 +++++++++---
 toolkit/config_modules.py          |  3 ++-
 toolkit/scheduler.py               | 21 ++++++++++++++-------
 3 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 0f8c23a6..068b6cd5 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -553,13 +553,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
 
+        lr_scheduler_params = self.train_config.lr_scheduler_params
+
+        # make sure it has the bare minimum
+        if 'total_iters' not in lr_scheduler_params:
+            lr_scheduler_params['total_iters'] = self.train_config.steps
+
         lr_scheduler = get_lr_scheduler(
             self.train_config.lr_scheduler,
             optimizer,
-            max_iterations=self.train_config.steps,
-            lr_min=self.train_config.lr / 100,
+            **lr_scheduler_params
         )
-
         self.lr_scheduler = lr_scheduler
 
         ### HOOK ###
@@ -601,6 +605,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # zero any gradients
         optimizer.zero_grad()
 
+        self.lr_scheduler.step(self.step_num)
+
         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
             with torch.no_grad():
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index b1e6cd81..3c44d31a 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -71,7 +71,9 @@ class TrainConfig:
         self.unet_lr = kwargs.get('unet_lr', self.lr)
         self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
         self.optimizer = kwargs.get('optimizer', 'adamw')
+        self.optimizer_params = kwargs.get('optimizer_params', {})
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
+        self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
         self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
         self.batch_size: int = kwargs.get('batch_size', 1)
         self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -80,7 +82,6 @@ class TrainConfig:
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
-        self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
         self.weight_jitter = kwargs.get('weight_jitter', 0.0)
diff --git a/toolkit/scheduler.py b/toolkit/scheduler.py
index ab5558a5..a6b97d8d 100644
--- a/toolkit/scheduler.py
+++ b/toolkit/scheduler.py
@@ -5,27 +5,34 @@ from typing import Optional
 def get_lr_scheduler(
     name: Optional[str],
     optimizer: torch.optim.Optimizer,
-    max_iterations: Optional[int],
-    lr_min: Optional[float],
     **kwargs,
 ):
     if name == "cosine":
+        if 'total_iters' in kwargs:
+            kwargs['T_max'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "cosine_with_restarts":
+        if 'total_iters' in kwargs:
+            kwargs['T_0'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "step":
+
         return torch.optim.lr_scheduler.StepLR(
-            optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
+            optimizer, **kwargs
         )
     elif name == "constant":
-        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
+        if 'factor' not in kwargs:
+            kwargs['factor'] = 1.0
+
+        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
     elif name == "linear":
+
         return torch.optim.lr_scheduler.LinearLR(
-            optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
+            optimizer, **kwargs
         )
     else:
         raise ValueError(
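
Usage sketch (not part of the patch): a minimal example of how the new pass-through behaves, using the toolkit.scheduler import path and the lr_scheduler_params field introduced above; the model, optimizer, step count, and eta_min values below are illustrative assumptions only.

    import torch
    from toolkit.scheduler import get_lr_scheduler

    # Stand-ins for the real network and train_config values (illustrative).
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Mirrors BaseSDTrainProcess: config-supplied lr_scheduler_params are passed
    # through, with total_iters defaulting to the configured step count.
    lr_scheduler_params = {'eta_min': 1e-6}
    if 'total_iters' not in lr_scheduler_params:
        lr_scheduler_params['total_iters'] = 1000  # train_config.steps

    # get_lr_scheduler remaps the generic 'total_iters' key onto the chosen
    # scheduler's own argument (T_max for cosine, T_0 for cosine_with_restarts).
    lr_scheduler = get_lr_scheduler('cosine', optimizer, **lr_scheduler_params)

    for step in range(1000):
        loss = model(torch.randn(1, 4)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

In a job config this corresponds to setting lr_scheduler alongside a new lr_scheduler_params block (and likewise optimizer_params for the optimizer), which TrainConfig now reads from kwargs.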