Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so I can manipulate it directly.

This commit is contained in:
Jaret Burkett
2023-09-16 17:41:07 -06:00
parent 27f343fc08
commit c698837241
11 changed files with 214 additions and 78 deletions

View File

@@ -27,11 +27,17 @@ class ComparativeTotalVariation(torch.nn.Module):
# Gradient penalty
def get_gradient_penalty(critic, real, fake, device):
with torch.autocast(device_type='cuda'):
alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
real = real.float()
fake = fake.float()
alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
if torch.isnan(interpolates).any():
print('d_interpolates is nan')
d_interpolates = critic(interpolates)
fake = torch.ones(real.size(0), 1, device=device)
if torch.isnan(d_interpolates).any():
print('fake is nan')
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
@@ -41,10 +47,14 @@ def get_gradient_penalty(critic, real, fake, device):
only_inputs=True,
)[0]
# see if any are nan
if torch.isnan(gradients).any():
print('gradients is nan')
gradients = gradients.view(gradients.size(0), -1)
gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
return gradient_penalty
return gradient_penalty.float()
class PatternLoss(torch.nn.Module):