diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 3001ea3c..532a1c8d 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -577,7 +577,7 @@ class SDTrainer(BaseSDTrainProcess):
                         dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
 
                     additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
-                elif self.dfe.version == 3 or self.dfe.version == 4:
+                elif self.dfe.version in [3, 4, 5]:
                     dfe_loss = self.dfe(
                         noise=noise,
                         noise_pred=noise_pred,
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 78a85c78..7774af29 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -470,10 +470,36 @@ class DiffusionFeatureExtractor4(nn.Module):
             output_hidden_states=True,
         )
 
-        # embeds = id_embeds['hidden_states'][-2]  # penultimate layer
-        image_embeds = id_embeds['pooler_output']
-        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        image_embeds = id_embeds['hidden_states'][-2]  # penultimate layer
+        # image_embeds = id_embeds['pooler_output']
+        # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
         return image_embeds
+
+    def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler):
+        """Step every sample from its current sigma directly to the scheduler's last sigma.
+
+        `noise` is not used by this implementation; subclasses that override
+        the method may need it.
+        """
+        bs = noise_pred.shape[0]
+        noise_pred_chunks = torch.chunk(noise_pred, bs)
+        timestep_chunks = torch.chunk(timesteps, bs)
+        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+        stepped_chunks = []
+        for idx in range(bs):
+            model_output = noise_pred_chunks[idx]
+            timestep = timestep_chunks[idx]
+            scheduler._step_index = None
+            scheduler._init_step_index(timestep)
+            sample = noisy_latent_chunks[idx].to(torch.float32)
+
+            sigma = scheduler.sigmas[scheduler.step_index]
+            sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
+            prev_sample = sample + (sigma_next - sigma) * model_output
+            stepped_chunks.append(prev_sample)
+
+        stepped_latents = torch.cat(stepped_chunks, dim=0)
+        return stepped_latents
 
     def forward(
         self,
@@ -509,26 +535,7 @@ class DiffusionFeatureExtractor4(nn.Module):
         if model is not None and hasattr(model, 'get_stepped_pred'):
             stepped_latents = model.get_stepped_pred(noise_pred, noise)
         else:
-            # stepped_latents = noise - noise_pred
-            # first we step the scheduler from current timestep to the very end for a full denoise
-            bs = noise_pred.shape[0]
-            noise_pred_chunks = torch.chunk(noise_pred, bs)
-            timestep_chunks = torch.chunk(timesteps, bs)
-            noisy_latent_chunks = torch.chunk(noisy_latents, bs)
-            stepped_chunks = []
-            for idx in range(bs):
-                model_output = noise_pred_chunks[idx]
-                timestep = timestep_chunks[idx]
-                scheduler._step_index = None
-                scheduler._init_step_index(timestep)
-                sample = noisy_latent_chunks[idx].to(torch.float32)
-
-                sigma = scheduler.sigmas[scheduler.step_index]
-                sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
-                prev_sample = sample + (sigma_next - sigma) * model_output
-                stepped_chunks.append(prev_sample)
-
-            stepped_latents = torch.cat(stepped_chunks, dim=0)
+            stepped_latents = self.step_latents(noise, noise_pred, noisy_latents, timesteps, scheduler)
 
         latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
 
@@ -590,6 +597,56 @@ class DiffusionFeatureExtractor4(nn.Module):
 
         self.step += 1
         return total_loss
+
+class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
+    """DiffusionFeatureExtractor4 with a partial-step `step_latents` that returns an x0 estimate."""
+
+    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
+        super().__init__(device=device, dtype=dtype, vae=vae)
+        self.version = 5
+
+    def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler, total_steps: int = 1000, eps: float = 1e-6):
+        """Step each sample 50 sigma indices ahead and return the x0 estimate.
+
+        `total_steps` divides the target sigma to form the blend factor and
+        `eps` is the lower bound applied to the x0-recovery denominator.
+        """
+        bs = noise_pred.shape[0]
+
+        # Chunk inputs per-sample (keeps existing structure)
+        noise_pred_chunks = torch.chunk(noise_pred, bs)
+        timestep_chunks = torch.chunk(timesteps, bs)
+        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+        noise_chunks = torch.chunk(noise, bs)
+
+        x0_pred_chunks = []
+
+        for idx in range(bs):
+            model_output = noise_pred_chunks[idx]  # predicted noise (same shape as latent)
+            timestep = timestep_chunks[idx]  # scalar tensor per sample (e.g., [t])
+            sample = noisy_latent_chunks[idx].to(torch.float32)
+            noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device)
+
+            # Initialize scheduler step index for this sample
+            scheduler._step_index = None
+            scheduler._init_step_index(timestep)
+
+            # ---- Step +50 indices (or to the end) in sigma-space ----
+            sigma = scheduler.sigmas[scheduler.step_index]
+            target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1)
+            sigma_next = scheduler.sigmas[target_idx]
+
+            # One-step update along the model-predicted direction
+            stepped = sample + (sigma_next - sigma) * model_output
+
+            # ---- Inverse-Gaussian recovery at the target timestep ----
+            # NOTE(review): if scheduler.sigmas is already in [0, 1], dividing by
+            # total_steps makes t_01 nearly zero -- confirm the intended scale.
+            t_01 = (sigma_next / total_steps).to(stepped.device).to(stepped.dtype)
+            original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01).clamp(min=eps)
+            x0_pred_chunks.append(original_samples)
+
+        return torch.cat(x0_pred_chunks, dim=0)
 
 def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
     if model_path == "v3":
@@ -600,6 +657,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
         dfe = DiffusionFeatureExtractor4(vae=vae)
         dfe.eval()
         return dfe
+    if model_path == "v5":
+        dfe = DiffusionFeatureExtractor5(vae=vae)
+        dfe.eval()
+        return dfe
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file not found: {model_path}")
     # if it ende with safetensors