mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Adjustments to defaults for automagic
This commit is contained in:
@@ -13,8 +13,9 @@ class Automagic(torch.optim.Optimizer):
|
||||
params,
|
||||
lr=None,
|
||||
min_lr=1e-7,
|
||||
max_lr=1e-2,
|
||||
lr_momentum=0.9,
|
||||
max_lr=1e-3,
|
||||
lr_pump_scale=1.1,
|
||||
lr_dump_scale=0.85,
|
||||
eps=(1e-30, 1e-3),
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
@@ -25,7 +26,8 @@ class Automagic(torch.optim.Optimizer):
|
||||
self.lr = lr
|
||||
self.min_lr = min_lr
|
||||
self.max_lr = max_lr
|
||||
self.lr_momentum = lr_momentum
|
||||
self.lr_pump_scale = lr_pump_scale
|
||||
self.lr_dump_scale = lr_dump_scale
|
||||
|
||||
defaults = {
|
||||
"lr": lr,
|
||||
@@ -249,8 +251,8 @@ class Automagic(torch.optim.Optimizer):
|
||||
# Update learning rate mask based on sign agreement
|
||||
new_lr = torch.where(
|
||||
sign_agreement > 0,
|
||||
lr_mask / self.lr_momentum, # Increase lr
|
||||
lr_mask * self.lr_momentum # Decrease lr
|
||||
lr_mask * self.lr_pump_scale, # Increase lr
|
||||
lr_mask * self.lr_dump_scale # Decrease lr
|
||||
)
|
||||
|
||||
# Clip learning rates to bounds
|
||||
|
||||
Reference in New Issue
Block a user