ostris/ai-toolkit (mirror of https://github.com/ostris/ai-toolkit.git)

Commit: Added code to handle diffusion feature extraction loss

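In short, training gains an optional feature-matching term: a small frozen convolutional network (the new DiffusionFeatureExtractor) maps the rectified-flow target (noise - latents) and the model's prediction, each concatenated with the noise along the channel dimension, into a shared feature space, and the mean squared error between the two feature maps is added to the regular loss, i.e. roughly loss = base_loss + MSE(DFE(cat(noise - latents, noise)), DFE(cat(noise_pred, noise))). The extractor is only loaded and used when diffusion_feature_extractor_path is set in the train config.
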
@@ -32,6 +32,7 @@ from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe


def flush():

@@ -78,6 +79,8 @@ class SDTrainer(BaseSDTrainProcess):
        self.cached_blank_embeds: Optional[PromptEmbeds] = None
        self.cached_trigger_embeds: Optional[PromptEmbeds] = None

        self.dfe: Optional[DiffusionFeatureExtractor] = None


    def before_model_load(self):

@@ -178,6 +181,11 @@ class SDTrainer(BaseSDTrainProcess):
            # move back to cpu
            self.sd.text_encoder_to('cpu')
            flush()

        if self.train_config.diffusion_feature_extractor_path is not None:
            self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path)
            self.dfe.to(self.device_torch)
            self.dfe.eval()


    def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):

@@ -285,6 +293,7 @@ class SDTrainer(BaseSDTrainProcess):
    ):
        loss_target = self.train_config.loss_target
        is_reg = any(batch.get_is_reg_list())
        additional_loss = 0.0

        prior_mask_multiplier = None
        target_mask_multiplier = None

@@ -367,7 +376,17 @@ class SDTrainer(BaseSDTrainProcess):
            target = (noise - batch.latents).detach()
        else:
            target = noise

        if self.dfe is not None:
            # do diffusion feature extraction on target
            with torch.no_grad():
                rectified_flow_target = noise.float() - batch.latents.float()
                target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))

            # do diffusion feature extraction on prediction
            pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
            additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean")

        if target is None:
            target = noise

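To see what the added term does in isolation, here is a minimal, self-contained sketch of the same wiring with dummy tensors. The batch size, the 16-channel 64x64 latent shape, and the standalone helper dfe_feature_loss are illustrative assumptions, not part of the commit; the snippet only relies on the DiffusionFeatureExtractor added in the new file below and assumes the toolkit package is importable.

import torch
import torch.nn.functional as F

from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor

# Hypothetical standalone version of the added loss term.
# Assumed shapes: batch of 2, 16-channel latents at 64x64, so the channel
# concatenation below gives the 32 input channels the extractor defaults to.
def dfe_feature_loss(dfe, noise, latents, noise_pred):
    with torch.no_grad():
        # rectified-flow target: the velocity the model is trained to predict
        rectified_flow_target = noise.float() - latents.float()
        target_features = dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
    # the prediction branch keeps gradients so the term can train the model
    pred_features = dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
    return F.mse_loss(pred_features, target_features, reduction="mean")

if __name__ == "__main__":
    dfe = DiffusionFeatureExtractor().eval()
    noise = torch.randn(2, 16, 64, 64)
    latents = torch.randn(2, 16, 64, 64)
    noise_pred = torch.randn(2, 16, 64, 64, requires_grad=True)
    loss = dfe_feature_loss(dfe, noise, latents, noise_pred)
    loss.backward()
    print(loss.item(), noise_pred.grad.shape)  # gradients reach the prediction; the target branch is detached
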
@@ -487,7 +506,7 @@ class SDTrainer(BaseSDTrainProcess):
            loss = loss + norm_std_loss


-        return loss
+        return loss + additional_loss

    def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
        return batch

@@ -400,6 +400,9 @@ class TrainConfig:
        self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
        # bypass the guidance embedding for training. For open flux with guidance embedding
        self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)

        # diffusion feature extractor
        self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)


class ModelConfig:

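For reference, the new option travels through the same kwargs the other TrainConfig fields use, so enabling it is a matter of adding one key. A minimal sketch, assuming TrainConfig lives in toolkit.config_modules and accepts these values as keyword arguments; the checkpoint path is a placeholder, not from the commit.

from toolkit.config_modules import TrainConfig  # module path assumed; TrainConfig is the class from the hunk above

train_kwargs = {
    'bypass_guidance_embedding': False,
    # new in this commit: points at a trained DiffusionFeatureExtractor checkpoint
    'diffusion_feature_extractor_path': '/path/to/dfe.safetensors',
}
train_config = TrainConfig(**train_kwargs)  # assumes the usual **kwargs constructor
assert train_config.diffusion_feature_extractor_path is not None
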
toolkit/models/diffusion_feature_extraction.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
import os
from torch import nn
from safetensors.torch import load_file


class DFEBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.GELU()

    def forward(self, x):
        x_in = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.act(x)
        x = x + x_in
        return x


class DiffusionFeatureExtractor(nn.Module):
    def __init__(self, in_channels=32):
        super().__init__()
        num_blocks = 6
        self.conv_in = nn.Conv2d(in_channels, 512, 1)
        self.conv_pool = nn.Conv2d(512, 512, 3, stride=2, padding=1)
        self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
        self.conv_out = nn.Conv2d(512, 512, 1)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.conv_pool(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv_out(x)
        return x


def load_dfe(model_path) -> DiffusionFeatureExtractor:
    dfe = DiffusionFeatureExtractor()
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # if it ends with safetensors
    if model_path.endswith('.safetensors'):
        state_dict = load_file(model_path)
    else:
        state_dict = torch.load(model_path, weights_only=True)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

    dfe.load_state_dict(state_dict)
    dfe.eval()
    return dfe

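A small usage sketch for the loader, showing both branches it handles (a .safetensors file and a torch checkpoint wrapped in 'model_state_dict'). The temporary paths are illustrative and the snippet assumes the toolkit package is importable.

import os
import tempfile

import torch
from safetensors.torch import save_file

from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe

dfe = DiffusionFeatureExtractor()

with tempfile.TemporaryDirectory() as tmp:
    # branch 1: a .safetensors checkpoint is read with safetensors.torch.load_file
    st_path = os.path.join(tmp, 'dfe.safetensors')
    save_file(dfe.state_dict(), st_path)
    loaded_a = load_dfe(st_path)

    # branch 2: anything else goes through torch.load, unwrapping 'model_state_dict' if present
    pt_path = os.path.join(tmp, 'dfe.pt')
    torch.save({'model_state_dict': dfe.state_dict()}, pt_path)
    loaded_b = load_dfe(pt_path)

# both loaders return the module in eval mode, ready to move to the training device
features = loaded_a(torch.randn(1, 32, 64, 64))
print(features.shape)  # conv_pool halves the spatial size: torch.Size([1, 512, 32, 32])
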