bug fixes

This commit is contained in:
Jaret Burkett
2023-09-09 15:04:44 -06:00
parent 2128ac1e08
commit 626ed2939a
2 changed files with 10 additions and 10 deletions

View File

@@ -120,18 +120,18 @@ class ConceptReplacer(BaseSDTrainProcess):
guidance_scale=1.0,
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
loss = loss.mean()
# back propagate loss to free ram
loss.backward()
flush()
# back propagate loss to free ram
loss.backward()
flush()
# apply gradients
self.optimizer.step()

View File

@@ -733,7 +733,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
self.sd.text_encoder.train()
self.sd.unet.train()
self.sd.unet.train()
### HOOK ###
loss_dict = self.hook_train_loop(batch)
flush()