Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so I can manipulate it.

This commit is contained in:
Jaret Burkett
2023-09-16 17:41:07 -06:00
parent 27f343fc08
commit c698837241
11 changed files with 214 additions and 78 deletions

View File

@@ -154,28 +154,28 @@ class Critic:
# train critic here
self.model.train()
self.model.requires_grad_(True)
self.optimizer.zero_grad()
critic_losses = []
for i in range(self.num_critic_per_gen):
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
stacked_output = self.model(inputs)
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
stacked_output = self.model(inputs).float()
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
self.optimizer.zero_grad()
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# avg loss
loss = np.mean(critic_losses)