Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
bug fixes
@@ -120,18 +120,18 @@ class ConceptReplacer(BaseSDTrainProcess):
                guidance_scale=1.0,
            )

            loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
            loss = loss.mean([1, 2, 3])

            if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
                # add min_snr_gamma
                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

            loss = loss.mean()

            # back propagate loss to free ram
            loss.backward()
            flush()

            # apply gradients
            self.optimizer.step()
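Note: mse_loss with reduction="none" keeps a per-element loss, and loss.mean([1, 2, 3]) averages over the channel and spatial dimensions so each sample in the batch carries one scalar loss before weighting. apply_snr_weight itself is not shown in this diff; the following is a minimal sketch of the min-SNR-gamma weighting strategy (Hang et al., 2023) it presumably applies, assuming epsilon-prediction and a diffusers-style scheduler that exposes alphas_cumprod. The real helper in ai-toolkit may differ in detail.

import torch

def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
    # Per-timestep signal-to-noise ratio: SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=loss.device, dtype=loss.dtype)
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    # min-SNR-gamma: weight = min(SNR, gamma) / SNR, which down-weights
    # low-noise (high-SNR) timesteps that would otherwise dominate training
    weight = torch.clamp(snr, max=gamma) / snr
    # loss has shape [batch] here (already reduced over C, H, W),
    # so the per-sample weights broadcast directly
    return loss * weight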
@@ -733,7 +733,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            else:
                self.sd.text_encoder.train()

            self.sd.unet.train()
            ### HOOK ###
            loss_dict = self.hook_train_loop(batch)
            flush()
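Note: the "back propagate loss to free ram" comment refers to loss.backward() releasing the saved activation graph once gradients are computed; flush() then reclaims that memory eagerly. flush() is not defined in either hunk; a plausible sketch, assuming it pairs Python garbage collection with emptying the CUDA allocator cache (the body below is an assumption, not ai-toolkit's actual code):

import gc

import torch

def flush():
    # Drop unreachable Python objects, including tensors whose
    # last reference was freed when the autograd graph was released
    gc.collect()
    # Return cached, unused CUDA blocks to the driver so later
    # allocations (e.g. the optimizer step) can reuse the memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()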