Added support for AdEMAMix8bit

This commit is contained in:
Jaret Burkett
2024-09-28 14:33:51 -06:00
parent a508caad1d
commit 69aa92bce5

View File

@@ -46,6 +46,8 @@ def get_optimizer(
if lower_type == "adam8bit":
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
if lower_type == "ademamix8bit":
return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
elif lower_type == "adamw8bit":
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
elif lower_type == "lion8bit":