# ai-toolkit/toolkit/optimizer.py

import torch


def get_optimizer(
        params,
        optimizer_type='adam',
        learning_rate=1e-6,
        optimizer_params=None
):
    """Build and return an optimizer for `params` based on a config-style type string."""
    if optimizer_params is None:
        optimizer_params = {}
    lower_type = optimizer_type.lower()
    if lower_type.startswith("dadaptation"):
        # D-Adaptation estimates its own step size; lr acts as a multiplier
        # and should stay near 1.0 rather than a conventional small value.
        import dadaptation
        print("Using D-Adaptation optimizer")
        use_lr = learning_rate
        if use_lr < 0.1:
            # A standard small lr was passed in; fall back to the recommended 1.0.
            use_lr = 1.0
        if lower_type.endswith('lion'):
            optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
        elif lower_type.endswith('adam'):
            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
        elif lower_type == 'dadaptation':
            # backwards compatibility for the old bare "dadaptation" type
            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
            print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
        else:
            raise ValueError(f'Unknown optimizer type {optimizer_type}')
elif lower_type.startswith("prodigy"):
from prodigyopt import Prodigy
print("Using Prodigy optimizer")
use_lr = learning_rate
if use_lr < 0.1:
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
use_lr = 1.0
# let net be the neural network you want to train
# you can choose weight decay value based on your problem, 0 by default
optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
    elif lower_type.endswith("8bit"):
        import bitsandbytes
        if lower_type == "adam8bit":
            optimizer = bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
        elif lower_type == "lion8bit":
            optimizer = bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
        else:
            raise ValueError(f'Unknown optimizer type {optimizer_type}')
    elif lower_type == 'adam':
        optimizer = torch.optim.Adam(params, lr=float(learning_rate), **optimizer_params)
    elif lower_type == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
    elif lower_type == 'lion':
        from lion_pytorch import Lion
        optimizer = Lion(params, lr=learning_rate, **optimizer_params)
    elif lower_type == 'adagrad':
        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
    else:
        raise ValueError(f'Unknown optimizer type {optimizer_type}')
    return optimizer
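

# --- Usage sketch (not part of the original file) -------------------------
# A minimal example of how get_optimizer might be called. The nn.Linear model,
# batch shapes, and learning rate below are illustrative assumptions, not
# values taken from ai-toolkit configs.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(16, 4)

    # Plain AdamW with an explicit learning rate.
    opt = get_optimizer(model.parameters(), optimizer_type='adamw', learning_rate=1e-4)

    # Prodigy/D-Adaptation variants manage their own step size, so learning_rate
    # can be left at its tiny default and will be bumped to 1.0 internally;
    # extra options are forwarded through optimizer_params, e.g.:
    # opt = get_optimizer(model.parameters(), optimizer_type='prodigy',
    #                     optimizer_params={'weight_decay': 0.01})

    # One illustrative training step.
    x = torch.randn(8, 16)
    target = torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    opt.step()
    opt.zero_grad()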