diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 5e405333..3747972d 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -378,15 +378,32 @@ class SDTrainer(BaseSDTrainProcess):
             target = noise
 
         if self.dfe is not None:
-            # do diffusion feature extraction on target
-            with torch.no_grad():
-                rectified_flow_target = noise.float() - batch.latents.float()
-                target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
-
-            # do diffusion feature extraction on prediction
-            pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
-            additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
-                self.train_config.diffusion_feature_extractor_weight
+            if self.dfe.version == 1:
+                # do diffusion feature extraction on target
+                with torch.no_grad():
+                    rectified_flow_target = noise.float() - batch.latents.float()
+                    target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
+
+                # do diffusion feature extraction on prediction
+                pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
+                additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
+                    self.train_config.diffusion_feature_extractor_weight
+            else:
+                # version 2
+                # do diffusion feature extraction on target
+                with torch.no_grad():
+                    rectified_flow_target = noise.float() - batch.latents.float()
+                    target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
+
+                # do diffusion feature extraction on prediction
+                pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
+
+                dfe_loss = 0.0
+                for i in range(len(target_feature_list)):
+                    dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
+
+                additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
+
 
         if target is None:
             target = noise
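For reference, the new version-2 branch reduces to a per-scale MSE summed over whatever feature pyramid the extractor returns, then scaled by the configured weight and the hard-coded 100.0 factor. A minimal standalone sketch of that reduction, with stand-in shapes and a stand-in weight rather than the trainer's real state:

```python
import torch
import torch.nn.functional as F

def multi_scale_dfe_loss(pred_features, target_features, weight):
    # One scalar MSE per pyramid level, summed, then scaled by the
    # configured weight and the trainer's hard-coded 100.0 factor.
    dfe_loss = sum(
        F.mse_loss(p, t, reduction="mean")
        for p, t in zip(pred_features, target_features)
    )
    return dfe_loss * weight * 100.0

# Stand-in pyramid matching the shapes advertised in forward() below.
shapes = [(1, 64, 512, 512), (1, 128, 256, 256), (1, 256, 128, 128),
          (1, 512, 64, 64), (1, 512, 32, 32)]
preds = [torch.randn(s) for s in shapes]
targets = [torch.randn(s) for s in shapes]
print(multi_scale_dfe_loss(preds, targets, weight=1e-3))
```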
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index e3b42ff3..e36b91a4 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -2,6 +2,116 @@ import torch
 import os
 from torch import nn
 from safetensors.torch import load_file
+import torch.nn.functional as F
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+        self.norm1 = nn.GroupNorm(8, out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+        self.norm2 = nn.GroupNorm(8, out_channels)
+        self.skip = nn.Conv2d(in_channels, out_channels,
+                              1) if in_channels != out_channels else nn.Identity()
+
+    def forward(self, x):
+        identity = self.skip(x)
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = F.silu(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+        x = F.silu(x + identity)
+        return x
+
+
+class DiffusionFeatureExtractor2(nn.Module):
+    def __init__(self, in_channels=32):
+        super().__init__()
+        self.version = 2
+
+        # Path 1: Upsample to 512x512 (1, 64, 512, 512)
+        self.up_path = nn.ModuleList([
+            nn.Conv2d(in_channels, 64, 3, padding=1),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(64, 64),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(64, 64),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(64, 64),
+            nn.Conv2d(64, 64, 3, padding=1),
+        ])
+
+        # Path 2: Upsample to 256x256 (1, 128, 256, 256)
+        self.path2 = nn.ModuleList([
+            nn.Conv2d(in_channels, 128, 3, padding=1),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(128, 128),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(128, 128),
+            nn.Conv2d(128, 128, 3, padding=1),
+        ])
+
+        # Path 3: Upsample to 128x128 (1, 256, 128, 128)
+        self.path3 = nn.ModuleList([
+            nn.Conv2d(in_channels, 256, 3, padding=1),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(256, 256),
+            nn.Conv2d(256, 256, 3, padding=1)
+        ])
+
+        # Path 4: Original size (1, 512, 64, 64)
+        self.path4 = nn.ModuleList([
+            nn.Conv2d(in_channels, 512, 3, padding=1),
+            ResBlock(512, 512),
+            ResBlock(512, 512),
+            nn.Conv2d(512, 512, 3, padding=1)
+        ])
+
+        # Path 5: Downsample to 32x32 (1, 512, 32, 32)
+        self.path5 = nn.ModuleList([
+            nn.Conv2d(in_channels, 512, 3, padding=1),
+            ResBlock(512, 512),
+            nn.AvgPool2d(2),
+            ResBlock(512, 512),
+            nn.Conv2d(512, 512, 3, padding=1)
+        ])
+
+    def forward(self, x):
+        outputs = []
+
+        # Path 1: 512x512
+        x1 = x
+        for layer in self.up_path:
+            x1 = layer(x1)
+        outputs.append(x1)  # [1, 64, 512, 512]
+
+        # Path 2: 256x256
+        x2 = x
+        for layer in self.path2:
+            x2 = layer(x2)
+        outputs.append(x2)  # [1, 128, 256, 256]
+
+        # Path 3: 128x128
+        x3 = x
+        for layer in self.path3:
+            x3 = layer(x3)
+        outputs.append(x3)  # [1, 256, 128, 128]
+
+        # Path 4: 64x64
+        x4 = x
+        for layer in self.path4:
+            x4 = layer(x4)
+        outputs.append(x4)  # [1, 512, 64, 64]
+
+        # Path 5: 32x32
+        x5 = x
+        for layer in self.path5:
+            x5 = layer(x5)
+        outputs.append(x5)  # [1, 512, 32, 32]
+
+        return outputs
 
 
 class DFEBlock(nn.Module):
@@ -23,6 +133,7 @@ class DFEBlock(nn.Module):
 class DiffusionFeatureExtractor(nn.Module):
     def __init__(self, in_channels=32):
         super().__init__()
+        self.version = 1
         num_blocks = 6
         self.conv_in = nn.Conv2d(in_channels, 512, 1)
         self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
@@ -37,7 +148,6 @@ class DiffusionFeatureExtractor(nn.Module):
 
 
 def load_dfe(model_path) -> DiffusionFeatureExtractor:
-    dfe = DiffusionFeatureExtractor()
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file not found: {model_path}")
     # if it ends with safetensors
@@ -48,6 +158,11 @@ def load_dfe(model_path) -> DiffusionFeatureExtractor:
         if 'model_state_dict' in state_dict:
             state_dict = state_dict['model_state_dict']
 
+    if 'conv_in.weight' in state_dict:
+        dfe = DiffusionFeatureExtractor()
+    else:
+        dfe = DiffusionFeatureExtractor2()
+
     dfe.load_state_dict(state_dict)
     dfe.eval()
     return dfe
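The loader now dispatches on a v1-only parameter name: `conv_in.weight` exists only in the original `DiffusionFeatureExtractor`, so its absence implies the multi-scale v2 model. A hedged round-trip to sanity-check that dispatch, assuming it runs in the same module as `load_dfe` (the file path is illustrative):

```python
import torch
from safetensors.torch import save_file

# Save a v2 extractor's weights; its state dict has no 'conv_in.weight',
# so load_dfe should reconstruct a DiffusionFeatureExtractor2.
save_file(DiffusionFeatureExtractor2().state_dict(), "/tmp/dfe_v2.safetensors")

dfe = load_dfe("/tmp/dfe_v2.safetensors")
assert dfe.version == 2

# A 64x64 latent with the default 32 input channels yields five feature
# maps, from the 512x512 up-path down to the 32x32 pooled path.
for f in dfe(torch.randn(1, 32, 64, 64)):
    print(tuple(f.shape))
```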
diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py
index c0509ee1..f0cfb91a 100644
--- a/toolkit/pipelines.py
+++ b/toolkit/pipelines.py
@@ -1285,20 +1285,21 @@ class FluxWithCFGPipeline(FluxPipeline):
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
-        (
-            negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
-            negative_text_ids,
-        ) = self.encode_prompt(
-            prompt=negative_prompt,
-            prompt_2=negative_prompt_2,
-            prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-        )
+        if guidance_scale > 1.00001:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                negative_text_ids,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -1361,21 +1362,25 @@ class FluxWithCFGPipeline(FluxPipeline):
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
+
+                if guidance_scale > 1.00001:
+                    # todo combine these
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=negative_text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
 
-                # todo combine these
-                noise_pred_uncond = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=negative_pooled_prompt_embeds,
-                    encoder_hidden_states=negative_prompt_embeds,
-                    txt_ids=negative_text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                else:
+                    noise_pred = noise_pred_text
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
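The new guard exploits the classifier-free-guidance identity: `uncond + s * (text - uncond)` collapses to `text` exactly at `s = 1`, so below the `1.00001` threshold the negative-prompt encode and the second transformer pass buy nothing and are skipped. A tiny numeric check of that identity, pure tensor math with no pipeline state:

```python
import torch

def cfg_combine(uncond, text, scale):
    # Standard classifier-free guidance: extrapolate from the
    # unconditional prediction toward the text-conditioned one.
    return uncond + scale * (text - uncond)

uncond = torch.randn(2, 16)
text = torch.randn(2, 16)

# At scale 1.0 the uncond term cancels, which is why the pipeline
# can drop the uncond forward pass without changing the output.
assert torch.allclose(cfg_combine(uncond, text, 1.0), text)
print(cfg_combine(uncond, text, 3.5)[0, :4])
```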