Added Critic support to VAE training. Still tweaking and working on it. Many other fixes

This commit is contained in:
Jaret Burkett
2023-07-19 15:57:32 -06:00
parent 6ada328d8d
commit 557732e7ff
9 changed files with 415 additions and 59 deletions

toolkit/optimizer.py (new file, +18)

@@ -0,0 +1,18 @@
import torch


def get_optimizer(
        params,
        optimizer_type='adam',
        learning_rate=1e-6
):
    if optimizer_type == 'dadaptation':
        # dadaptation optimizer does not use standard learning rate. 1 is the default value
        import dadaptation
        print("Using DAdaptAdam optimizer")
        optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
    elif optimizer_type == 'adam':
        optimizer = torch.optim.Adam(params, lr=float(learning_rate))
    else:
        raise ValueError(f'Unknown optimizer type {optimizer_type}')
    return optimizer
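
For context, a minimal usage sketch of the helper added above (the placeholder module and learning rate here are illustrative, not taken from this commit): the function expects a parameter iterable and an optimizer type string, and raises on unknown types.

import torch.nn as nn
from toolkit.optimizer import get_optimizer

model = nn.Linear(4, 4)  # placeholder module standing in for the VAE or critic
optimizer = get_optimizer(
    model.parameters(),
    optimizer_type='adam',   # or 'dadaptation' if the dadaptation package is installed
    learning_rate=1e-4       # ignored for 'dadaptation', which is fixed at lr=1.0
)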