diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py index d2a1e925..4ec38501 100644 --- a/toolkit/optimizer.py +++ b/toolkit/optimizer.py @@ -46,6 +46,8 @@ def get_optimizer( if lower_type == "adam8bit": return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) + if lower_type == "ademamix8bit": + return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) elif lower_type == "adamw8bit": return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) elif lower_type == "lion8bit":