mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Add stepped loss type
This commit is contained in:
@@ -34,7 +34,7 @@ from diffusers import EMAModel
|
||||
import math
|
||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||
from toolkit.util.wavelet_loss import wavelet_loss
|
||||
from toolkit.util.losses import wavelet_loss, stepped_loss
|
||||
import torch.nn.functional as F
|
||||
from toolkit.unloader import unload_text_encoder
|
||||
from PIL import Image
|
||||
@@ -679,6 +679,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
elif self.train_config.loss_type == "wavelet":
|
||||
loss = wavelet_loss(pred, batch.latents, noise)
|
||||
elif self.train_config.loss_type == "stepped":
|
||||
loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler)
|
||||
# the way this loss works, it is low, increase it to match predictable LR effects
|
||||
loss = loss * 10.0
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user