From 73c8b5097549ecfdeb835f1a2e096807097f6ccb Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 24 Oct 2023 11:16:01 -0600
Subject: [PATCH] Added ability to use adafactor from transformers

---
 toolkit/optimizer.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index b4449f73..bcb15761 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -1,4 +1,5 @@
 import torch
+from transformers import Adafactor
 
 
 def get_optimizer(
@@ -54,10 +55,15 @@ def get_optimizer(
     elif lower_type == 'adamw':
         optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'lion':
-        from lion_pytorch import Lion
-        return Lion(params, lr=learning_rate, **optimizer_params)
+        try:
+            from lion_pytorch import Lion
+            return Lion(params, lr=learning_rate, **optimizer_params)
+        except ImportError:
+            raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
     elif lower_type == 'adagrad':
         optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+    elif lower_type == 'adafactor':
+        optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
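
Usage note (not part of the patch): a minimal sketch of how the new 'adafactor' branch might be called. The keyword names in the call below are assumptions, since only the body of get_optimizer is visible in the hunk, not its full signature. Also note that transformers.Adafactor raises a ValueError when an explicit lr is combined with its default relative_step=True schedule, so callers would likely need to disable that via the extra optimizer params.

    import torch
    from toolkit.optimizer import get_optimizer

    model = torch.nn.Linear(16, 4)  # stand-in model for illustration only

    # Keyword names here are assumptions; only the behaviour of the new
    # 'adafactor' branch is taken from the patch itself.
    optimizer = get_optimizer(
        params=model.parameters(),
        optimizer_type='adafactor',
        learning_rate=1e-3,
        optimizer_params={
            # transformers.Adafactor refuses an explicit lr while its default
            # relative_step=True schedule is active, so turn the schedule off.
            'relative_step': False,
            'scale_parameter': False,
            'warmup_init': False,
        },
    )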