Allow special args for schedulers

Jaret Burkett
2023-09-03 20:38:44 -06:00
parent 7cd6945082
commit 22ed539321
3 changed files with 25 additions and 11 deletions
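
Taken together, the three diffs below let a training config pass arbitrary keyword arguments through to the learning-rate scheduler instead of relying on hard-coded values. A minimal sketch of what a caller might now supply (the dict-based config shape is an illustrative assumption; only the `lr_scheduler` and `lr_scheduler_params` keys come from this commit):

    # Hypothetical config fragment; keys besides lr_scheduler and
    # lr_scheduler_params are illustrative, not defined by this commit.
    train_kwargs = {
        'lr': 1e-4,
        'lr_scheduler': 'cosine',
        'lr_scheduler_params': {
            'total_iters': 2000,   # remapped to T_max by the cosine branch
            'eta_min': 1e-6,
        },
    }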


@@ -553,13 +553,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
+        lr_scheduler_params = self.train_config.lr_scheduler_params
+        # make sure it has the bare minimum
+        if 'total_iters' not in lr_scheduler_params:
+            lr_scheduler_params['total_iters'] = self.train_config.steps
         lr_scheduler = get_lr_scheduler(
             self.train_config.lr_scheduler,
             optimizer,
-            max_iterations=self.train_config.steps,
-            lr_min=self.train_config.lr / 100,
+            **lr_scheduler_params
         )
         self.lr_scheduler = lr_scheduler
         ### HOOK ###
@@ -601,6 +605,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # zero any gradients
         optimizer.zero_grad()
+        self.lr_scheduler.step(self.step_num)
+        # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
             with torch.no_grad():
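
The newly added `self.lr_scheduler.step(self.step_num)` call fast-forwards the scheduler when training resumes partway through, so the learning rate continues from the saved step rather than restarting. The same idea in isolation (optimizer and scheduler choices here are illustrative; passing an explicit step to `step()` is deprecated in recent PyTorch releases, though still functional):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000)

    resume_step = 500
    scheduler.step(resume_step)     # jump the schedule to the resume point
    print(scheduler.get_last_lr())  # LR as it would be at step 500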


@@ -71,7 +71,9 @@ class TrainConfig:
         self.unet_lr = kwargs.get('unet_lr', self.lr)
         self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
         self.optimizer = kwargs.get('optimizer', 'adamw')
+        self.optimizer_params = kwargs.get('optimizer_params', {})
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
+        self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
         self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
         self.batch_size: int = kwargs.get('batch_size', 1)
         self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -80,7 +82,6 @@ class TrainConfig:
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
-        self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
         self.weight_jitter = kwargs.get('weight_jitter', 0.0)
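
Because every field is read with `kwargs.get(...)` and both new params default to empty dicts, existing configs that never mention them keep working unchanged. A quick sketch (assuming `TrainConfig` accepts plain keyword arguments; any other required fields are omitted here):

    # Hypothetical instantiation, for illustration only.
    cfg = TrainConfig(
        lr=1e-4,
        optimizer='adamw',
        optimizer_params={'weight_decay': 1e-2},
        lr_scheduler='constant',
        # lr_scheduler_params omitted -> defaults to {} and the trainer
        # fills in total_iters from train_config.steps
    )
    assert cfg.lr_scheduler_params == {}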


@@ -5,27 +5,34 @@ from typing import Optional
 def get_lr_scheduler(
         name: Optional[str],
         optimizer: torch.optim.Optimizer,
-        max_iterations: Optional[int],
-        lr_min: Optional[float],
         **kwargs,
 ):
     if name == "cosine":
+        if 'total_iters' in kwargs:
+            kwargs['T_max'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "cosine_with_restarts":
+        if 'total_iters' in kwargs:
+            kwargs['T_0'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "step":
         return torch.optim.lr_scheduler.StepLR(
-            optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
+            optimizer, **kwargs
         )
     elif name == "constant":
-        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
+        if 'factor' not in kwargs:
+            kwargs['factor'] = 1.0
+        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
     elif name == "linear":
         return torch.optim.lr_scheduler.LinearLR(
-            optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
+            optimizer, **kwargs
         )
     else:
         raise ValueError(
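
With the hard-coded arguments gone, a single `total_iters` key now works across schedulers because each branch remaps it to the name the underlying constructor expects (`T_max` for cosine, `T_0` for warm restarts). A runnable sketch of that behavior (the import path for `get_lr_scheduler` is an assumption; adjust it to wherever the function lives in this repo):

    import torch
    from toolkit.scheduler import get_lr_scheduler  # import path assumed

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)

    # the same generic kwargs work across different schedulers
    cosine = get_lr_scheduler('cosine', optimizer, total_iters=2000, eta_min=1e-6)
    assert cosine.T_max == 2000

    restarts = get_lr_scheduler('cosine_with_restarts', optimizer, total_iters=2000)
    assert restarts.T_0 == 2000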