mirror of https://github.com/ostris/ai-toolkit.git
Added ability to use adagrad from transformers
@@ -1,4 +1,5 @@
 import torch
+from transformers import Adafactor
 
 
 def get_optimizer(
@@ -54,10 +55,15 @@ def get_optimizer(
     elif lower_type == 'adamw':
         optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'lion':
-        from lion_pytorch import Lion
-        return Lion(params, lr=learning_rate, **optimizer_params)
+        try:
+            from lion_pytorch import Lion
+            return Lion(params, lr=learning_rate, **optimizer_params)
+        except ImportError:
+            raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
     elif lower_type == 'adagrad':
         optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+    elif lower_type == 'adafactor':
+        optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
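For context, a minimal usage sketch of the patched function follows. The full signature of get_optimizer is not visible in this diff, so the keyword names below are assumptions inferred from the variables used in the hunk. One caveat worth noting: transformers' Adafactor defaults to relative_step=True, which raises a ValueError when combined with an explicit lr, so optimizer_params would likely need to disable the relative schedule.

import torch

model = torch.nn.Linear(16, 4)

# Hypothetical call; parameter names are assumed, not confirmed by the diff.
optimizer = get_optimizer(
    params=model.parameters(),
    optimizer_type='adafactor',   # routed to transformers.Adafactor by the new branch
    learning_rate=1e-3,
    optimizer_params={
        # transformers.Adafactor cannot combine a manual lr with its default
        # relative_step=True, so a fixed-lr setup disables the internal
        # schedule and parameter-scaling heuristics.
        'relative_step': False,
        'scale_parameter': False,
        'warmup_init': False,
    },
)

As a design note, the try/except around lion_pytorch keeps that dependency optional: the import only fails, with an actionable install hint, when the Lion optimizer is actually requested.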