Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: Added v2 of dfp
@@ -378,15 +378,32 @@ class SDTrainer(BaseSDTrainProcess):
                 target = noise
             if self.dfe is not None:
-                # do diffusion feature extraction on target
-                with torch.no_grad():
-                    rectified_flow_target = noise.float() - batch.latents.float()
-                    target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
-
-                # do diffusion feature extraction on prediction
-                pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
-                additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
-                    self.train_config.diffusion_feature_extractor_weight
+                if self.dfe.version == 1:
+                    # do diffusion feature extraction on target
+                    with torch.no_grad():
+                        rectified_flow_target = noise.float() - batch.latents.float()
+                        target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
+
+                    # do diffusion feature extraction on prediction
+                    pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
+                    additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
+                        self.train_config.diffusion_feature_extractor_weight
+                else:
+                    # version 2
+                    # do diffusion feature extraction on target
+                    with torch.no_grad():
+                        rectified_flow_target = noise.float() - batch.latents.float()
+                        target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
+
+                    # do diffusion feature extraction on prediction
+                    pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
+
+                    dfe_loss = 0.0
+                    for i in range(len(target_feature_list)):
+                        dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
+
+                    additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0

         if target is None:
             target = noise
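The v2 branch treats the extractor output as a list of feature maps and sums a per-scale MSE. A minimal sketch of that loss, pulled out of the trainer for clarity; `dfe`, `noise`, `latents`, `noise_pred`, and `weight` stand in for the trainer's own objects (`self.dfe`, `batch.latents`, `self.train_config.diffusion_feature_extractor_weight`), and the 16-channel latent shape is an assumption inferred from the extractor's in_channels=32 on a concatenated pair:

import torch
import torch.nn.functional as F

def dfe_v2_loss(dfe, noise, latents, noise_pred, weight):
    # target is the rectified-flow velocity: noise minus clean latents
    with torch.no_grad():
        rf_target = noise.float() - latents.float()
        target_feats = dfe(torch.cat([rf_target, noise.float()], dim=1))
    pred_feats = dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
    # sum mean-squared error across every scale the extractor returns
    loss = sum(F.mse_loss(p, t, reduction="mean")
               for p, t in zip(pred_feats, target_feats))
    # the commit scales the configured weight by 100x in the v2 path
    return loss * weight * 100.0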
@@ -2,6 +2,116 @@ import torch
 import os
 from torch import nn
 from safetensors.torch import load_file
+import torch.nn.functional as F
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+        self.norm1 = nn.GroupNorm(8, out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+        self.norm2 = nn.GroupNorm(8, out_channels)
+        self.skip = nn.Conv2d(in_channels, out_channels,
+                              1) if in_channels != out_channels else nn.Identity()
+
+    def forward(self, x):
+        identity = self.skip(x)
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = F.silu(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+        x = F.silu(x + identity)
+        return x
+
+
+class DiffusionFeatureExtractor2(nn.Module):
+    def __init__(self, in_channels=32):
+        super().__init__()
+        self.version = 2
+
+        # Path 1: Upsample to 512x512 (1, 64, 512, 512)
+        self.up_path = nn.ModuleList([
+            nn.Conv2d(in_channels, 64, 3, padding=1),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(64, 64),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(64, 64),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(64, 64),
+            nn.Conv2d(64, 64, 3, padding=1),
+        ])
+
+        # Path 2: Upsample to 256x256 (1, 128, 256, 256)
+        self.path2 = nn.ModuleList([
+            nn.Conv2d(in_channels, 128, 3, padding=1),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(128, 128),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(128, 128),
+            nn.Conv2d(128, 128, 3, padding=1),
+        ])
+
+        # Path 3: Upsample to 128x128 (1, 256, 128, 128)
+        self.path3 = nn.ModuleList([
+            nn.Conv2d(in_channels, 256, 3, padding=1),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+            ResBlock(256, 256),
+            nn.Conv2d(256, 256, 3, padding=1)
+        ])
+
+        # Path 4: Original size (1, 512, 64, 64)
+        self.path4 = nn.ModuleList([
+            nn.Conv2d(in_channels, 512, 3, padding=1),
+            ResBlock(512, 512),
+            ResBlock(512, 512),
+            nn.Conv2d(512, 512, 3, padding=1)
+        ])
+
+        # Path 5: Downsample to 32x32 (1, 512, 32, 32)
+        self.path5 = nn.ModuleList([
+            nn.Conv2d(in_channels, 512, 3, padding=1),
+            ResBlock(512, 512),
+            nn.AvgPool2d(2),
+            ResBlock(512, 512),
+            nn.Conv2d(512, 512, 3, padding=1)
+        ])
+
+    def forward(self, x):
+        outputs = []
+
+        # Path 1: 512x512
+        x1 = x
+        for layer in self.up_path:
+            x1 = layer(x1)
+        outputs.append(x1)  # [1, 64, 512, 512]
+
+        # Path 2: 256x256
+        x2 = x
+        for layer in self.path2:
+            x2 = layer(x2)
+        outputs.append(x2)  # [1, 128, 256, 256]
+
+        # Path 3: 128x128
+        x3 = x
+        for layer in self.path3:
+            x3 = layer(x3)
+        outputs.append(x3)  # [1, 256, 128, 128]
+
+        # Path 4: 64x64
+        x4 = x
+        for layer in self.path4:
+            x4 = layer(x4)
+        outputs.append(x4)  # [1, 512, 64, 64]
+
+        # Path 5: 32x32
+        x5 = x
+        for layer in self.path5:
+            x5 = layer(x5)
+        outputs.append(x5)  # [1, 512, 32, 32]
+
+        return outputs


 class DFEBlock(nn.Module):
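A quick shape check for the new module (a sketch; the input size follows the path comments, i.e. a 64x64, 32-channel input built from two concatenated 16-channel latents):

import torch

dfe2 = DiffusionFeatureExtractor2(in_channels=32)
with torch.no_grad():
    feats = dfe2(torch.randn(1, 32, 64, 64))
for f in feats:
    print(tuple(f.shape))
# per the path comments:
# (1, 64, 512, 512), (1, 128, 256, 256), (1, 256, 128, 128),
# (1, 512, 64, 64), (1, 512, 32, 32)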
@@ -23,6 +133,7 @@ class DFEBlock(nn.Module):
 class DiffusionFeatureExtractor(nn.Module):
     def __init__(self, in_channels=32):
         super().__init__()
+        self.version = 1
         num_blocks = 6
         self.conv_in = nn.Conv2d(in_channels, 512, 1)
         self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
@@ -37,7 +148,6 @@ class DiffusionFeatureExtractor(nn.Module):


 def load_dfe(model_path) -> DiffusionFeatureExtractor:
-    dfe = DiffusionFeatureExtractor()
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file not found: {model_path}")
     # if it ends with safetensors
@@ -48,6 +158,11 @@ def load_dfe(model_path) -> DiffusionFeatureExtractor:
     if 'model_state_dict' in state_dict:
         state_dict = state_dict['model_state_dict']
+
+    if 'conv_in.weight' in state_dict:
+        dfe = DiffusionFeatureExtractor()
+    else:
+        dfe = DiffusionFeatureExtractor2()
     dfe.load_state_dict(state_dict)
     dfe.eval()
     return dfe
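With the last two hunks, load_dfe infers the architecture from the checkpoint itself: 'conv_in.weight' only exists in the v1 model, so its presence selects DiffusionFeatureExtractor and its absence selects DiffusionFeatureExtractor2, and the version attribute set in each __init__ then drives the trainer's branch. A usage sketch (the path below is illustrative, not a file in the repo):

dfe = load_dfe("/path/to/dfe_checkpoint.safetensors")
print(dfe.version)  # 1 or 2; SDTrainer branches on this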
@@ -1285,20 +1285,21 @@ class FluxWithCFGPipeline(FluxPipeline):
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
-        (
-            negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
-            negative_text_ids,
-        ) = self.encode_prompt(
-            prompt=negative_prompt,
-            prompt_2=negative_prompt_2,
-            prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
-        )
+        if guidance_scale > 1.00001:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                negative_text_ids,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
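Gating the negative-prompt encoding on guidance_scale > 1.00001 means a run without real CFG never encodes the negative prompt at all. The threshold is presumably a float-safe strict greater-than-1 test; a sketch of the same check as a standalone helper (hypothetical, not in the commit):

def cfg_enabled(guidance_scale: float, eps: float = 1e-5) -> bool:
    # treat anything within float noise of 1.0 as "CFG off"
    return guidance_scale > 1.0 + eps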
@@ -1361,21 +1362,25 @@ class FluxWithCFGPipeline(FluxPipeline):
                 joint_attention_kwargs=self.joint_attention_kwargs,
                 return_dict=False,
             )[0]

-            # todo combine these
-            noise_pred_uncond = self.transformer(
-                hidden_states=latents,
-                timestep=timestep / 1000,
-                guidance=guidance,
-                pooled_projections=negative_pooled_prompt_embeds,
-                encoder_hidden_states=negative_prompt_embeds,
-                txt_ids=negative_text_ids,
-                img_ids=latent_image_ids,
-                joint_attention_kwargs=self.joint_attention_kwargs,
-                return_dict=False,
-            )[0]
-
-            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
+            if guidance_scale > 1.00001:
+                # todo combine these
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    txt_ids=negative_text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            else:
+                noise_pred = noise_pred_text

             # compute the previous noisy sample x_t -> x_t-1
             latents_dtype = latents.dtype
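The same gate is applied inside the denoising loop: with CFG on, the transformer runs a second, unconditional pass and the two predictions are blended; with CFG off, the text prediction is used directly, halving transformer calls per step. A self-contained sketch of the combine step (guidance_scale here mirrors self.guidance_scale in the pipeline):

import torch

def combine_cfg(noise_pred_text: torch.Tensor,
                noise_pred_uncond: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    if guidance_scale > 1.00001:
        # standard classifier-free guidance: push the conditional
        # prediction further away from the unconditional one
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    return noise_pred_text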