From 626ed2939a01da9e357035d73403b17e8b1be1c7 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 9 Sep 2023 15:04:44 -0600
Subject: [PATCH] bug fixes

---
 .../concept_replacer/ConceptReplacer.py | 18 +++++++++---------
 jobs/process/BaseSDTrainProcess.py      |  2 +-
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/extensions_built_in/concept_replacer/ConceptReplacer.py b/extensions_built_in/concept_replacer/ConceptReplacer.py
index 04d4d42d..1600e8e1 100644
--- a/extensions_built_in/concept_replacer/ConceptReplacer.py
+++ b/extensions_built_in/concept_replacer/ConceptReplacer.py
@@ -120,18 +120,18 @@ class ConceptReplacer(BaseSDTrainProcess):
                 guidance_scale=1.0,
             )
 
-            loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
-            loss = loss.mean([1, 2, 3])
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
+        loss = loss.mean([1, 2, 3])
 
-            if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
-                # add min_snr_gamma
-                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
+        if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+            # add min_snr_gamma
+            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
 
-            loss = loss.mean()
+        loss = loss.mean()
 
-            # back propagate loss to free ram
-            loss.backward()
-            flush()
+        # back propagate loss to free ram
+        loss.backward()
+        flush()
 
         # apply gradients
         self.optimizer.step()
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index cbe6312b..45e77dcc 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -733,7 +733,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             else:
                 self.sd.text_encoder.train()
-                self.sd.unet.train()
+            self.sd.unet.train()
 
             ### HOOK ###
             loss_dict = self.hook_train_loop(batch)
             flush()