mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added Critic support to VAE training. Still tweaking and working on it. Many other fixes
This commit is contained in:
18
toolkit/optimizer.py
Normal file
18
toolkit/optimizer.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
optimizer_type='adam',
|
||||
learning_rate=1e-6
|
||||
):
|
||||
if optimizer_type == 'dadaptation':
|
||||
# dadaptation optimizer does not use standard learning rate. 1 is the default value
|
||||
import dadaptation
|
||||
print("Using DAdaptAdam optimizer")
|
||||
optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
|
||||
elif optimizer_type == 'adam':
|
||||
optimizer = torch.optim.Adam(params, lr=float(learning_rate))
|
||||
else:
|
||||
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
||||
return optimizer
|
||||
Reference in New Issue
Block a user