mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Improvements to vae trainer. Adjust denoise prediction of DFE v3
This commit is contained in:
@@ -220,10 +220,15 @@ class Critic:
|
||||
|
||||
return float(np.mean(critic_losses))
|
||||
|
||||
def get_lr(self):
|
||||
if self.optimizer_type.startswith('dadaptation'):
|
||||
return (
|
||||
self.optimizer.param_groups[0]["d"]
|
||||
* self.optimizer.param_groups[0]["lr"]
|
||||
def get_lr(self):
|
||||
if hasattr(self.optimizer, 'get_avg_learning_rate'):
|
||||
learning_rate = self.optimizer.get_avg_learning_rate()
|
||||
elif self.optimizer_type.startswith('dadaptation') or \
|
||||
self.optimizer_type.lower().startswith('prodigy'):
|
||||
learning_rate = (
|
||||
self.optimizer.param_groups[0]["d"] *
|
||||
self.optimizer.param_groups[0]["lr"]
|
||||
)
|
||||
return self.optimizer.param_groups[0]["lr"]
|
||||
else:
|
||||
learning_rate = self.optimizer.param_groups[0]['lr']
|
||||
return learning_rate
|
||||
|
||||
Reference in New Issue
Block a user