mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Various code to support experiments.
This commit is contained in:
@@ -438,7 +438,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
|
||||
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||
elif self.dfe.version == 3:
|
||||
elif self.dfe.version == 3 or self.dfe.version == 4:
|
||||
dfe_loss = self.dfe(
|
||||
noise=noise,
|
||||
noise_pred=noise_pred,
|
||||
@@ -518,7 +518,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
v2=self.train_config.linear_timesteps2,
|
||||
timestep_type=self.train_config.timestep_type
|
||||
).to(loss.device, dtype=loss.dtype)
|
||||
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
|
||||
if len(loss.shape) == 4:
|
||||
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
|
||||
elif len(loss.shape) == 5:
|
||||
timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach()
|
||||
loss = loss * timestep_weight
|
||||
|
||||
if self.train_config.do_prior_divergence and prior_pred is not None:
|
||||
|
||||
Reference in New Issue
Block a user