mirror of https://github.com/ostris/ai-toolkit.git
Allow special args for schedulers
@@ -5,27 +5,34 @@ from typing import Optional
 def get_lr_scheduler(
         name: Optional[str],
         optimizer: torch.optim.Optimizer,
         max_iterations: Optional[int],
         lr_min: Optional[float],
         **kwargs,
 ):
     if name == "cosine":
+        if 'total_iters' in kwargs:
+            kwargs['T_max'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "cosine_with_restarts":
+        if 'total_iters' in kwargs:
+            kwargs['T_0'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
+            optimizer, **kwargs
         )
     elif name == "step":
         return torch.optim.lr_scheduler.StepLR(
-            optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
+            optimizer, **kwargs
         )
     elif name == "constant":
-        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
+        if 'factor' not in kwargs:
+            kwargs['factor'] = 1.0
+
+        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
     elif name == "linear":
         return torch.optim.lr_scheduler.LinearLR(
-            optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
+            optimizer, **kwargs
         )
     else:
         raise ValueError(
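With this change, scheduler-specific arguments pass straight through **kwargs to the underlying torch scheduler, and a generic total_iters option is remapped to the name each scheduler expects (T_max for cosine, T_0 for cosine_with_restarts). A minimal usage sketch, not part of the commit: the import path toolkit.scheduler, the toy model, and the hyperparameters are assumptions for illustration only.

# Usage sketch; `toolkit.scheduler` as the module path is an assumption.
import torch

from toolkit.scheduler import get_lr_scheduler  # path assumed

model = torch.nn.Linear(8, 8)  # stand-in model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# The generic `total_iters` is remapped internally to T_max for the
# cosine schedule; `eta_min` is forwarded unchanged to
# torch.optim.lr_scheduler.CosineAnnealingLR.
scheduler = get_lr_scheduler(
    "cosine",
    optimizer,
    max_iterations=1000,  # still in the signature, no longer forwarded
    lr_min=1e-6,          # still in the signature, no longer forwarded
    total_iters=1000,     # becomes T_max for the cosine schedule
    eta_min=1e-6,         # passed through as a special arg
)

for _ in range(3):
    optimizer.step()
    scheduler.step()

Note that max_iterations and lr_min remain in the signature for compatibility but are no longer forwarded, so values such as eta_min now have to arrive via the keyword arguments.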