Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Added my good ole pattern loss. God I love that thing, conv transpose pattern instantly wiped from vae
@@ -6,13 +6,16 @@ def get_optimizer(
         optimizer_type='adam',
         learning_rate=1e-6
 ):
-    if optimizer_type == 'dadaptation':
+    lower_type = optimizer_type.lower()
+    if lower_type == 'dadaptation':
         # dadaptation optimizer does not use standard learning rate. 1 is the default value
         import dadaptation
         print("Using DAdaptAdam optimizer")
         optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
-    elif optimizer_type == 'adam':
+    elif lower_type == 'adam':
         optimizer = torch.optim.Adam(params, lr=float(learning_rate))
+    elif lower_type == 'adamw':
+        optimizer = torch.optim.AdamW(params, lr=float(learning_rate))
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
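
For readability, here is the patched function as a whole plus a small usage sketch. It is reconstructed from the hunk above; the leading params argument and the torch import are assumptions, since they fall outside the diff context.

import torch

def get_optimizer(
        params,                      # assumed leading argument (not shown in the hunk)
        optimizer_type='adam',
        learning_rate=1e-6
):
    # the requested optimizer name is now lower-cased once before matching
    lower_type = optimizer_type.lower()
    if lower_type == 'dadaptation':
        # dadaptation optimizer does not use standard learning rate. 1 is the default value
        import dadaptation
        print("Using DAdaptAdam optimizer")
        optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
    elif lower_type == 'adam':
        optimizer = torch.optim.Adam(params, lr=float(learning_rate))
    elif lower_type == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=float(learning_rate))
    else:
        raise ValueError(f'Unknown optimizer type {optimizer_type}')
    return optimizer

# Example: matching is now case-insensitive, so 'AdamW' and 'adamw' both
# resolve to torch.optim.AdamW; unknown names still raise ValueError.
model = torch.nn.Linear(16, 16)
optimizer = get_optimizer(model.parameters(), optimizer_type='AdamW', learning_rate=1e-4)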