diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index b4449f73..bcb15761 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -1,4 +1,5 @@
 import torch
+from transformers import Adafactor
 
 
 def get_optimizer(
@@ -54,10 +55,15 @@ def get_optimizer(
     elif lower_type == 'adamw':
         optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'lion':
-        from lion_pytorch import Lion
-        return Lion(params, lr=learning_rate, **optimizer_params)
+        try:
+            from lion_pytorch import Lion
+            return Lion(params, lr=learning_rate, **optimizer_params)
+        except ImportError:
+            raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
     elif lower_type == 'adagrad':
         optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+    elif lower_type == 'adafactor':
+        optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
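
A usage sketch for the new 'adafactor' branch (not part of the patch). One caveat worth noting: the transformers Adafactor implementation raises ValueError("Cannot combine manual lr and relative_step=True options") whenever an explicit lr is passed with its defaults, so as written the new branch will fail at construction time unless the caller disables relative steps via optimizer_params. The keyword names optimizer_type, learning_rate, and optimizer_params below are assumed from the identifiers visible in the hunk; the actual get_optimizer signature may differ.

# Usage sketch, not part of the patch; see caveats above.
import torch

from toolkit.optimizer import get_optimizer

model = torch.nn.Linear(16, 4)  # hypothetical stand-in model

optimizer = get_optimizer(
    model.parameters(),
    optimizer_type='adafactor',
    learning_rate=1e-3,
    optimizer_params={
        # transformers' Adafactor refuses an explicit lr while
        # relative_step=True (its default), so both time-dependent
        # options are disabled here.
        'relative_step': False,
        'scale_parameter': False,
        'warmup_init': False,  # warmup_init requires relative_step=True
    },
)

If the intent is for 'adafactor' to work out of the box like the other branches, the patch could instead set relative_step=False and scale_parameter=False as defaults before splatting **optimizer_params, leaving callers free to override them.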