diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 5f1bacda..d1d6507c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -664,7 +664,7 @@ class SDTrainer(BaseSDTrainProcess): dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 - elif self.dfe.version in [3, 4, 5, 6, 7]: + elif self.dfe.version in [3, 4, 5, 6, 7, 8]: dfe_loss = self.dfe( noise=noise, noise_pred=noise_pred, diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index e9d8bdd2..9eb1606d 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -808,8 +808,16 @@ class DiffusionFeatureExtractor6(nn.Module): class DiffusionFeatureExtractor7(nn.Module): - def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None, sd=None): + def __init__( + self, + device=torch.device("cuda"), + dtype=torch.bfloat16, + vae=None, + sd=None, + partial_step: bool = False + ): super().__init__() + self.version = 7 self.sd_ref = weakref.ref(sd) if sd is not None else None pretrained_model_name = "google/tipsv2-b14-dpt" @@ -823,6 +831,7 @@ class DiffusionFeatureExtractor7(nn.Module): self.losses = {} self.log_every = 100 self.step = 0 + self.do_partial_step = partial_step def prepare_inputs(self, tensor_0_1: torch.Tensor): """ @@ -886,14 +895,31 @@ class DiffusionFeatureExtractor7(nn.Module): # expand shape to match noise_pred while len(tv.shape) < len(noise_pred.shape): tv = tv.unsqueeze(-1) - # min 0.001 - tv = torch.clamp(tv, min=0.001) - - # step latent - x0 = noisy_latents - tv * noise_pred - stepped_latents = x0 + with torch.no_grad(): + target_0_1 = (tensors + 1) / 2 # 0 to 1 + + if not self.do_partial_step: + # step latent + x0 = 
noisy_latents - tv * noise_pred +             stepped_latents = x0 +             # min 0.001 +             tv = torch.clamp(tv, min=0.001) +         else: +             # step is random 0.02 to 0.05 +             step = torch.rand_like(tv) * 0.03 + 0.02 +             next_step = tv - step +             next_step = torch.clamp(next_step, min=0.0) +             stepped_latents = noisy_latents + (next_step - tv) * noise_pred +             with torch.no_grad(): +                 # make a noisy target at next timestep +                 target_latents = batch.latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype) +                 # add noise +                 target_latents = (1.0 - next_step) * target_latents + next_step * noise +                 target_n1p1 = self.sd_ref().decode_latents(target_latents) +                 target_0_1 = (target_n1p1 + 1) / 2  # 0 to 1 +          latents = stepped_latents.to(self.sd_ref().vae.device, dtype=self.sd_ref().vae.dtype)          tensors_n1p1 = self.sd_ref().decode_latents(latents) @@ -904,10 +930,7 @@ class DiffusionFeatureExtractor7(nn.Module):          dtype = self.model.dtype           with torch.no_grad(): -            target_img = tensors.to(device, dtype=dtype) -            # go from -1 to 1 to 0 to 1 -            target_img = (target_img + 1) / 2 -            target = self.prepare_inputs(target_img) +            target = self.prepare_inputs(target_0_1)              target = self.model(target)           pred_images = pred_images.to(device, dtype=dtype) @@ -929,6 +952,9 @@ class DiffusionFeatureExtractor7(nn.Module):           total_loss = (depth_loss + normals_loss + segmentation_loss) / 3.0 +        if self.do_partial_step: +            total_loss = total_loss * 10.0 +         if 'total' not in self.losses:             self.losses['total'] = total_loss.item()         else: @@ -963,6 +989,11 @@ class DiffusionFeatureExtractor7(nn.Module):          return total_loss  +class DiffusionFeatureExtractor8(DiffusionFeatureExtractor7): +    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None, sd=None): +        super().__init__(device=device, dtype=dtype, vae=vae, sd=sd, partial_step=True) +        self.version = 8 +  def load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureExtractor:     if model_path == "v3":         dfe = DiffusionFeatureExtractor3(vae=vae) @@ -984,6 +1015,10 @@ def 
load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureEx dfe = DiffusionFeatureExtractor7(vae=vae, sd=sd) dfe.eval() return dfe + if model_path == "v8": + dfe = DiffusionFeatureExtractor8(vae=vae, sd=sd) + dfe.eval() + return dfe if not os.path.exists(model_path): raise FileNotFoundError(f"Model file not found: {model_path}") # if it ende with safetensors