mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Fixes to esrgan trainer. Moved logic for sd prompt embeddings out of diffusers pipeline so I can manipulate it
This commit is contained in:
@@ -154,28 +154,28 @@ class Critic:
|
||||
# train critic here
|
||||
self.model.train()
|
||||
self.model.requires_grad_(True)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
critic_losses = []
|
||||
for i in range(self.num_critic_per_gen):
|
||||
inputs = vgg_output.detach()
|
||||
inputs = inputs.to(self.device, dtype=self.torch_dtype)
|
||||
self.optimizer.zero_grad()
|
||||
inputs = vgg_output.detach()
|
||||
inputs = inputs.to(self.device, dtype=self.torch_dtype)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
|
||||
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
|
||||
|
||||
stacked_output = self.model(inputs)
|
||||
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
|
||||
stacked_output = self.model(inputs).float()
|
||||
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
|
||||
|
||||
# Compute gradient penalty
|
||||
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
|
||||
# Compute gradient penalty
|
||||
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
|
||||
|
||||
# Compute WGAN-GP critic loss
|
||||
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
|
||||
critic_loss.backward()
|
||||
self.optimizer.zero_grad()
|
||||
self.optimizer.step()
|
||||
self.scheduler.step()
|
||||
critic_losses.append(critic_loss.item())
|
||||
# Compute WGAN-GP critic loss
|
||||
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
|
||||
critic_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
||||
self.optimizer.step()
|
||||
self.scheduler.step()
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# avg loss
|
||||
loss = np.mean(critic_losses)
|
||||
|
||||
Reference in New Issue
Block a user