diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 479b414..1b9f35b 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -414,15 +414,41 @@ class SDTrainer(BaseSDTrainProcess):
 
         if self.dfe is not None:
             if self.dfe.version == 1:
-                # do diffusion feature extraction on target
+                model = self.sd
+                if model is not None and hasattr(model, 'get_stepped_pred'):
+                    stepped_latents = model.get_stepped_pred(noise_pred, noise)
+                else:
+                    # stepped_latents = noise - noise_pred
+                    # first we step the scheduler from current timestep to the very end for a full denoise
+                    bs = noise_pred.shape[0]
+                    noise_pred_chunks = torch.chunk(noise_pred, bs)
+                    timestep_chunks = torch.chunk(timesteps, bs)
+                    noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+                    stepped_chunks = []
+                    for idx in range(bs):
+                        model_output = noise_pred_chunks[idx]
+                        timestep = timestep_chunks[idx]
+                        self.sd.noise_scheduler._step_index = None
+                        self.sd.noise_scheduler._init_step_index(timestep)
+                        sample = noisy_latent_chunks[idx].to(torch.float32)
+
+                        sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index]
+                        sigma_next = self.sd.noise_scheduler.sigmas[-1]  # use last sigma for final step
+                        prev_sample = sample + (sigma_next - sigma) * model_output
+                        stepped_chunks.append(prev_sample)
+
+                    stepped_latents = torch.cat(stepped_chunks, dim=0)
+
+                stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
+                pred_features = self.dfe(stepped_latents.float())
                 with torch.no_grad():
-                    rectified_flow_target = noise.float() - batch.latents.float()
-                    target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
+                    target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
+                # scale dfe so it is weaker at higher noise levels
+                dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
 
-                # do diffusion feature extraction on prediction
-                pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
-                additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
-                    self.train_config.diffusion_feature_extractor_weight
+                dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
+                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
+                additional_loss += dfe_loss.mean()
             elif self.dfe.version == 2:
                 # version 2
                 # do diffusion feature extraction on target
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index c1fd396..c682b58 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -454,8 +454,12 @@ class TrainConfig:
         self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
 
         # diffusion feature extractor
-        self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
-        self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
+        self.latent_feature_extractor_path = kwargs.get('latent_feature_extractor_path', None)
+        self.latent_feature_loss_weight = kwargs.get('latent_feature_loss_weight', 1.0)
+
+        # we use this in the code, but it really needs to be called latent_feature_extractor as that makes more sense with new architecture
+        self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', self.latent_feature_extractor_path)
+        self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', self.latent_feature_loss_weight)
 
         # optimal noise pairing
         self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 01d7f27..0f7bff0 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -127,18 +127,20 @@ class DFEBlock(nn.Module):
         self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
         self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
         self.act = nn.GELU()
+        self.proj = nn.Conv2d(channels, channels, 1)
 
     def forward(self, x):
         x_in = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.act(x)
+        x = self.proj(x)
         x = x + x_in
         return x
 
 
 class DiffusionFeatureExtractor(nn.Module):
-    def __init__(self, in_channels=32):
+    def __init__(self, in_channels=16):
         super().__init__()
         self.version = 1
         num_blocks = 6
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 4bffc81..3041b1a 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -2538,8 +2538,8 @@ class StableDiffusion:
 
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
-            self.vae.to(self.device)
-        latents = latents.to(device, dtype=dtype)
+            self.vae.to(self.device_torch)
+        latents = latents.to(self.device_torch, dtype=self.torch_dtype)
         latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
         images = self.vae.decode(latents).sample
         images = images.to(device, dtype=dtype)
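A minimal usage sketch (not part of the patch) of the renamed config keys and their backwards-compatible fallbacks from the toolkit/config_modules.py hunk. It assumes TrainConfig accepts these keys as keyword arguments, as the kwargs.get calls suggest, and the checkpoint path is hypothetical:

# Hypothetical usage sketch for the renamed keys; not part of the patch.
from toolkit.config_modules import TrainConfig

train_config = TrainConfig(
    latent_feature_extractor_path='/path/to/dfe_checkpoint.safetensors',  # hypothetical path
    latent_feature_loss_weight=0.5,
)

# The old attribute names still resolve, falling back to the new values,
# so existing call sites such as SDTrainer keep working unchanged.
assert train_config.diffusion_feature_extractor_path == '/path/to/dfe_checkpoint.safetensors'
assert train_config.diffusion_feature_extractor_weight == 0.5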