Allow special args for schedulers

This commit is contained in:
Jaret Burkett
2023-09-03 20:38:44 -06:00
parent 7cd6945082
commit 22ed539321
3 changed files with 25 additions and 11 deletions

View File

@@ -5,27 +5,34 @@ from typing import Optional
def get_lr_scheduler(
name: Optional[str],
optimizer: torch.optim.Optimizer,
max_iterations: Optional[int],
lr_min: Optional[float],
**kwargs,
):
if name == "cosine":
if 'total_iters' in kwargs:
kwargs['T_max'] = kwargs.pop('total_iters')
return torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
optimizer, **kwargs
)
elif name == "cosine_with_restarts":
if 'total_iters' in kwargs:
kwargs['T_0'] = kwargs.pop('total_iters')
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
optimizer, **kwargs
)
elif name == "step":
return torch.optim.lr_scheduler.StepLR(
optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
optimizer, **kwargs
)
elif name == "constant":
return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
if 'facor' not in kwargs:
kwargs['factor'] = 1.0
return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
elif name == "linear":
return torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
optimizer, **kwargs
)
else:
raise ValueError(