Allow special args for schedulers
@@ -553,13 +553,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 optimizer_params=self.train_config.optimizer_params)
             self.optimizer = optimizer
 
+            lr_scheduler_params = self.train_config.lr_scheduler_params
+
+            # make sure it has the bare minimum
+            if 'max_iterations' not in lr_scheduler_params:
+                lr_scheduler_params['total_iters'] = self.train_config.steps
             lr_scheduler = get_lr_scheduler(
                 self.train_config.lr_scheduler,
                 optimizer,
-                max_iterations=self.train_config.steps,
-                lr_min=self.train_config.lr / 100,
+                **lr_scheduler_params
             )
 
             self.lr_scheduler = lr_scheduler
 
         ### HOOK ###
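The net effect of this hunk is that scheduler-specific kwargs from the config are forwarded to get_lr_scheduler, with total_iters defaulted from the step count. A minimal sketch of what the default 'cosine' path resolves to, using plain torch; the model, optimizer, and step count below are stand-ins, not values from the repo:

```python
import torch

# Stand-ins for the trainer's model/optimizer/config; not taken from the repo.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
steps = 1000                 # plays the role of train_config.steps
lr_scheduler_params = {}     # plays the role of train_config.lr_scheduler_params

# Same fallback as in the hunk: make sure an iteration count is present.
if 'max_iterations' not in lr_scheduler_params:
    lr_scheduler_params['total_iters'] = steps

# For lr_scheduler == 'cosine', get_lr_scheduler remaps total_iters -> T_max
# (see the final hunk), so the forwarded params are equivalent to:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=lr_scheduler_params['total_iters']
)
```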
@@ -601,6 +605,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # zero any gradients
         optimizer.zero_grad()
 
+        self.lr_scheduler.step(self.step_num)
+
         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
             with torch.no_grad():
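Stepping the scheduler with an explicit step number fast-forwards the learning-rate schedule when a run resumes partway through. A small self-contained illustration with a placeholder optimizer and values; note that newer PyTorch versions emit a deprecation warning for the explicit-epoch form of step():

```python
import torch

model = torch.nn.Linear(4, 4)                        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

resume_step = 50                                      # placeholder for self.step_num
scheduler.step(resume_step)                           # jump the schedule to the resumed step
print(optimizer.param_groups[0]['lr'])                # 0.5 at the cosine midpoint
```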
@@ -71,7 +71,9 @@ class TrainConfig:
         self.unet_lr = kwargs.get('unet_lr', self.lr)
         self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
         self.optimizer = kwargs.get('optimizer', 'adamw')
+        self.optimizer_params = kwargs.get('optimizer_params', {})
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
+        self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
         self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
         self.batch_size: int = kwargs.get('batch_size', 1)
         self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -80,7 +82,6 @@ class TrainConfig:
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
-        self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
         self.weight_jitter = kwargs.get('weight_jitter', 0.0)
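With lr_scheduler_params (and the relocated optimizer_params) on TrainConfig, scheduler arguments can come straight from the job config. A hypothetical kwargs dict, only to show the shape of the new fields; none of these values come from a real job file:

```python
# Hypothetical job-config kwargs; the keys mirror the TrainConfig fields above.
kwargs = {
    'optimizer': 'adamw',
    'optimizer_params': {'weight_decay': 1e-2},
    'lr_scheduler': 'cosine',
    'lr_scheduler_params': {'eta_min': 1e-6},  # forwarded as-is to get_lr_scheduler
}

# TrainConfig reads them with empty-dict defaults, exactly as in the hunk:
lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
assert lr_scheduler_params == {'eta_min': 1e-6}
```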
@@ -5,27 +5,34 @@ from typing import Optional
 def get_lr_scheduler(
     name: Optional[str],
     optimizer: torch.optim.Optimizer,
-    max_iterations: Optional[int],
-    lr_min: Optional[float],
     **kwargs,
 ):
     if name == "cosine":
+        if 'total_iters' in kwargs:
+            kwargs['T_max'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "cosine_with_restarts":
+        if 'total_iters' in kwargs:
+            kwargs['T_0'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "step":
+
         return torch.optim.lr_scheduler.StepLR(
-            optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
+            optimizer, **kwargs
         )
     elif name == "constant":
-        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
+        if 'factor' not in kwargs:
+            kwargs['factor'] = 1.0
+
+        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
     elif name == "linear":
+
         return torch.optim.lr_scheduler.LinearLR(
-            optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
+            optimizer, **kwargs
         )
     else:
         raise ValueError(
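Because the leftover kwargs go straight to the torch scheduler constructors, any scheduler-specific option can now be supplied through lr_scheduler_params, with total_iters translated to T_max / T_0 for the cosine variants. A quick check of that mapping using torch directly; the values are illustrative only:

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# What a call like get_lr_scheduler('cosine', optimizer, total_iters=500, eta_min=1e-6)
# reduces to after the total_iters -> T_max remapping shown above:
kwargs = {'total_iters': 500, 'eta_min': 1e-6}
if 'total_iters' in kwargs:
    kwargs['T_max'] = kwargs.pop('total_iters')

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **kwargs)
print(scheduler.T_max, scheduler.eta_min)  # 500 1e-06
```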