Added code to handle diffusion feature extraction loss

This commit is contained in:
Jaret Burkett
2025-01-21 14:21:34 -07:00
parent 6a8e3d8610
commit 29122b1a54
3 changed files with 79 additions and 2 deletions

View File

@@ -32,6 +32,7 @@ from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
def flush():
@@ -78,6 +79,8 @@ class SDTrainer(BaseSDTrainProcess):
self.cached_blank_embeds: Optional[PromptEmbeds] = None
self.cached_trigger_embeds: Optional[PromptEmbeds] = None
self.dfe: Optional[DiffusionFeatureExtractor] = None
def before_model_load(self):
@@ -178,6 +181,11 @@ class SDTrainer(BaseSDTrainProcess):
# move back to cpu
self.sd.text_encoder_to('cpu')
flush()
if self.train_config.diffusion_feature_extractor_path is not None:
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path)
self.dfe.to(self.device_torch)
self.dfe.eval()
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
@@ -285,6 +293,7 @@ class SDTrainer(BaseSDTrainProcess):
):
loss_target = self.train_config.loss_target
is_reg = any(batch.get_is_reg_list())
additional_loss = 0.0
prior_mask_multiplier = None
target_mask_multiplier = None
@@ -367,7 +376,17 @@ class SDTrainer(BaseSDTrainProcess):
target = (noise - batch.latents).detach()
else:
target = noise
if self.dfe is not None:
# do diffusion feature extraction on target
with torch.no_grad():
rectified_flow_target = noise.float() - batch.latents.float()
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
# do diffusion feature extraction on prediction
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean")
if target is None:
target = noise
@@ -487,7 +506,7 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss + norm_std_loss
return loss
return loss + additional_loss
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch

View File

@@ -400,6 +400,9 @@ class TrainConfig:
self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
# bypass the guidance embedding for training. For open flux with guidance embedding
self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
# diffusion feature extractor
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
class ModelConfig:

View File

@@ -0,0 +1,55 @@
import torch
import os
from torch import nn
from safetensors.torch import load_file
class DFEBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.act = nn.GELU()
def forward(self, x):
x_in = x
x = self.conv1(x)
x = self.conv2(x)
x = self.act(x)
x = x + x_in
return x
class DiffusionFeatureExtractor(nn.Module):
def __init__(self, in_channels=32):
super().__init__()
num_blocks = 6
self.conv_in = nn.Conv2d(in_channels, 512, 1)
self.conv_pool = nn.Conv2d(512, 512, 3, stride=2, padding=1)
self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
self.conv_out = nn.Conv2d(512, 512, 1)
def forward(self, x):
x = self.conv_in(x)
x = self.conv_pool(x)
for block in self.blocks:
x = block(x)
x = self.conv_out(x)
return x
def load_dfe(model_path) -> DiffusionFeatureExtractor:
dfe = DiffusionFeatureExtractor()
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")
# if it ende with safetensors
if model_path.endswith('.safetensors'):
state_dict = load_file(model_path)
else:
state_dict = torch.load(model_path, weights_only=True)
if 'model_state_dict' in state_dict:
state_dict = state_dict['model_state_dict']
dfe.load_state_dict(state_dict)
dfe.eval()
return dfe