mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Adjustments to defaults for automagic
This commit is contained in:
@@ -13,8 +13,9 @@ class Automagic(torch.optim.Optimizer):
|
|||||||
params,
|
params,
|
||||||
lr=None,
|
lr=None,
|
||||||
min_lr=1e-7,
|
min_lr=1e-7,
|
||||||
max_lr=1e-2,
|
max_lr=1e-3,
|
||||||
lr_momentum=0.9,
|
lr_pump_scale=1.1,
|
||||||
|
lr_dump_scale=0.85,
|
||||||
eps=(1e-30, 1e-3),
|
eps=(1e-30, 1e-3),
|
||||||
clip_threshold=1.0,
|
clip_threshold=1.0,
|
||||||
decay_rate=-0.8,
|
decay_rate=-0.8,
|
||||||
@@ -25,7 +26,8 @@ class Automagic(torch.optim.Optimizer):
|
|||||||
self.lr = lr
|
self.lr = lr
|
||||||
self.min_lr = min_lr
|
self.min_lr = min_lr
|
||||||
self.max_lr = max_lr
|
self.max_lr = max_lr
|
||||||
self.lr_momentum = lr_momentum
|
self.lr_pump_scale = lr_pump_scale
|
||||||
|
self.lr_dump_scale = lr_dump_scale
|
||||||
|
|
||||||
defaults = {
|
defaults = {
|
||||||
"lr": lr,
|
"lr": lr,
|
||||||
@@ -249,8 +251,8 @@ class Automagic(torch.optim.Optimizer):
|
|||||||
# Update learning rate mask based on sign agreement
|
# Update learning rate mask based on sign agreement
|
||||||
new_lr = torch.where(
|
new_lr = torch.where(
|
||||||
sign_agreement > 0,
|
sign_agreement > 0,
|
||||||
lr_mask / self.lr_momentum, # Increase lr
|
lr_mask * self.lr_pump_scale, # Increase lr
|
||||||
lr_mask * self.lr_momentum # Decrease lr
|
lr_mask * self.lr_dump_scale # Decrease lr
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clip learning rates to bounds
|
# Clip learning rates to bounds
|
||||||
|
|||||||
Reference in New Issue
Block a user