Files
ai-toolkit/toolkit/scheduler.py
2023-09-03 20:38:44 -06:00

41 lines
1.1 KiB
Python

import torch
from typing import Optional
def get_lr_scheduler(
        name: Optional[str],
        optimizer: torch.optim.Optimizer,
        **kwargs,
):
    """Build a ``torch.optim.lr_scheduler`` instance selected by name.

    Args:
        name: One of ``"cosine"``, ``"cosine_with_restarts"``, ``"step"``,
            ``"constant"`` or ``"linear"``.
        optimizer: The optimizer whose learning rate the scheduler drives.
        **kwargs: Forwarded to the underlying torch scheduler constructor.
            For ``"cosine"`` a ``total_iters`` key is renamed to ``T_max``;
            for ``"cosine_with_restarts"`` it is renamed to ``T_0``.
            For ``"constant"``, ``factor`` defaults to ``1.0`` when not given.

    Returns:
        The constructed torch LR scheduler.

    Raises:
        ValueError: If ``name`` is not one of the supported scheduler names.
    """
    if name == "cosine":
        # CosineAnnealingLR names its schedule length ``T_max``; accept the
        # generic ``total_iters`` key used by the other schedulers.
        if 'total_iters' in kwargs:
            kwargs['T_max'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, **kwargs
        )
    elif name == "cosine_with_restarts":
        # CosineAnnealingWarmRestarts names its first-cycle length ``T_0``.
        if 'total_iters' in kwargs:
            kwargs['T_0'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, **kwargs
        )
    elif name == "step":
        return torch.optim.lr_scheduler.StepLR(
            optimizer, **kwargs
        )
    elif name == "constant":
        # Default to factor=1.0 (LR left unscaled). Do this only when the
        # caller did not pass a factor, so an explicit value is respected.
        kwargs.setdefault('factor', 1.0)
        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
    elif name == "linear":
        return torch.optim.lr_scheduler.LinearLR(
            optimizer, **kwargs
        )
    else:
        raise ValueError(
            "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
        )