Added ability to use adagrad from transformers

Jaret Burkett
2023-10-24 11:16:01 -06:00
parent 34eb563d55
commit 73c8b50975


@@ -1,4 +1,5 @@
 import torch
+from transformers import Adafactor


 def get_optimizer(
@@ -54,10 +55,15 @@ def get_optimizer(
     elif lower_type == 'adamw':
         optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'lion':
-        from lion_pytorch import Lion
-        return Lion(params, lr=learning_rate, **optimizer_params)
+        try:
+            from lion_pytorch import Lion
+            return Lion(params, lr=learning_rate, **optimizer_params)
+        except ImportError:
+            raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
+    elif lower_type == 'adagrad':
+        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+    elif lower_type == 'adafactor':
+        optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
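
For reference, a minimal usage sketch of the new optimizer choices. The import path and the keyword names (optimizer_type, learning_rate, optimizer_params) are assumptions inferred from the variables used in the function body; the full get_optimizer signature is truncated in this diff.

# Hypothetical usage sketch; the module path below is an assumption, not shown in this diff.
import torch

from toolkit.optimizer import get_optimizer  # assumed module path

model = torch.nn.Linear(8, 2)

# New in this commit: 'adagrad' routes to torch.optim.Adagrad.
adagrad_opt = get_optimizer(
    model.parameters(),
    optimizer_type='adagrad',
    learning_rate=1e-3,
    optimizer_params={},
)

# 'adafactor' now uses transformers' Adafactor; with an explicit lr it
# requires relative_step=False (and typically scale_parameter=False).
adafactor_opt = get_optimizer(
    model.parameters(),
    optimizer_type='adafactor',
    learning_rate=1e-3,
    optimizer_params={'scale_parameter': False, 'relative_step': False},
)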