Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-05-01)
Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so I can manipulate it.
@@ -27,11 +27,17 @@ class ComparativeTotalVariation(torch.nn.Module):
 # Gradient penalty
 def get_gradient_penalty(critic, real, fake, device):
     with torch.autocast(device_type='cuda'):
-        alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
+        real = real.float()
+        fake = fake.float()
+        alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
         interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
+        if torch.isnan(interpolates).any():
+            print('d_interpolates is nan')
         d_interpolates = critic(interpolates)
         fake = torch.ones(real.size(0), 1, device=device)
 
+        if torch.isnan(d_interpolates).any():
+            print('fake is nan')
         gradients = torch.autograd.grad(
             outputs=d_interpolates,
             inputs=interpolates,
@@ -41,10 +47,14 @@ def get_gradient_penalty(critic, real, fake, device):
             only_inputs=True,
         )[0]
 
+        # see if any are nan
+        if torch.isnan(gradients).any():
+            print('gradients is nan')
+
         gradients = gradients.view(gradients.size(0), -1)
         gradient_norm = gradients.norm(2, dim=1)
         gradient_penalty = ((gradient_norm - 1) ** 2).mean()
-    return gradient_penalty
+    return gradient_penalty.float()
 
 
 class PatternLoss(torch.nn.Module):
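For context, get_gradient_penalty implements the WGAN-GP gradient penalty: real and fake batches are mixed with a random per-sample alpha, the critic scores the interpolates, and the penalty pushes the L2 norm of the critic's gradient toward 1. The sketch below shows one way such a helper could be wired into a critic update step. It is a minimal, hypothetical example: the Critic module, the lambda_gp weight of 10, and the optimizer settings are illustrative assumptions and are not part of this commit; get_gradient_penalty is assumed to be in scope (e.g. imported from the patched losses module).

import torch
import torch.nn as nn

# Assumption: get_gradient_penalty from the patched module above is importable here.

class Critic(nn.Module):
    # Tiny stand-in critic for illustration; the real ESRGAN critic is more elaborate.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.net(x)


def critic_step(critic, optimizer, real, fake, device, lambda_gp=10.0):
    # Standard WGAN-GP critic objective: D(fake) - D(real) + lambda * gradient penalty.
    optimizer.zero_grad()
    loss_real = critic(real).mean()
    loss_fake = critic(fake.detach()).mean()
    gp = get_gradient_penalty(critic, real, fake.detach(), device)
    loss = loss_fake - loss_real + lambda_gp * gp
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage with random tensors, just to show the shapes involved.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
critic = Critic().to(device)
opt = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
real = torch.rand(4, 3, 64, 64, device=device)
fake = torch.rand(4, 3, 64, 64, device=device)
critic_step(critic, opt, real, fake, device)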