import torch
from typing import Optional


def get_lr_scheduler(
        name: Optional[str],
        optimizer: torch.optim.Optimizer,
        **kwargs,
):
    # Build a torch LR scheduler from a short name. A generic 'total_iters'
    # kwarg is remapped to the parameter each cosine variant expects.
    if name == "cosine":
        if 'total_iters' in kwargs:
            kwargs['T_max'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, **kwargs
        )
    elif name == "cosine_with_restarts":
        if 'total_iters' in kwargs:
            kwargs['T_0'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, **kwargs
        )
    elif name == "step":
        return torch.optim.lr_scheduler.StepLR(
            optimizer, **kwargs
        )
    elif name == "constant":
        # ConstantLR multiplies the LR by 'factor'; default to 1.0 so the
        # learning rate stays unchanged unless the caller overrides it.
        if 'factor' not in kwargs:
            kwargs['factor'] = 1.0
        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
    elif name == "linear":
        return torch.optim.lr_scheduler.LinearLR(
            optimizer, **kwargs
        )
    else:
        raise ValueError(
            "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
        )
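

# --- Usage sketch (illustrative, not part of the upstream module) ---
# A minimal example of wiring get_lr_scheduler into a training loop. The
# model, learning rate, and total_iters value below are assumptions chosen
# for demonstration only; 'total_iters' is remapped to T_max for "cosine".
if __name__ == "__main__":
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = get_lr_scheduler("cosine", optimizer, total_iters=1000)
    for _ in range(3):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())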