mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
bug fixes
This commit is contained in:
@@ -120,18 +120,18 @@ class ConceptReplacer(BaseSDTrainProcess):
|
|||||||
guidance_scale=1.0,
|
guidance_scale=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||||
# add min_snr_gamma
|
# add min_snr_gamma
|
||||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||||
|
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
# back propagate loss to free ram
|
# back propagate loss to free ram
|
||||||
loss.backward()
|
loss.backward()
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
# apply gradients
|
# apply gradients
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|||||||
@@ -733,7 +733,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
else:
|
else:
|
||||||
self.sd.text_encoder.train()
|
self.sd.text_encoder.train()
|
||||||
|
|
||||||
self.sd.unet.train()
|
self.sd.unet.train()
|
||||||
### HOOK ###
|
### HOOK ###
|
||||||
loss_dict = self.hook_train_loop(batch)
|
loss_dict = self.hook_train_loop(batch)
|
||||||
flush()
|
flush()
|
||||||
|
|||||||
Reference in New Issue
Block a user