More work on mean flow loss. Moved it to an adapter. Still not functioning properly though.

Jaret Burkett
2025-06-16 07:17:35 -06:00
parent c0314ba325
commit 1c2b7298dd
6 changed files with 323 additions and 165 deletions

View File

@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.mean_flow_adapter import MeanFlowAdapter
from toolkit.models.i2v_adapter import I2VAdapter
from toolkit.models.subpixel_adapter import SubpixelAdapter
from toolkit.models.ilora import InstantLoRAModule
@@ -98,6 +99,7 @@ class CustomAdapter(torch.nn.Module):
self.single_value_adapter: SingleValueAdapter = None
self.redux_adapter: ReduxImageEncoder = None
self.control_lora: ControlLoraAdapter = None
self.mean_flow_adapter: MeanFlowAdapter = None
self.subpixel_adapter: SubpixelAdapter = None
self.i2v_adapter: I2VAdapter = None
@@ -125,6 +127,16 @@ class CustomAdapter(torch.nn.Module):
dtype=self.sd_ref().dtype,
)
self.load_state_dict(loaded_state_dict, strict=False)
@property
def do_direct_save(self):
# some adapters save their weights directly; others, like IP adapters, split the state dict
if self.config.train_only_image_encoder:
return True
if self.config.type in ['control_lora', 'subpixel', 'i2v', 'redux', 'mean_flow']:
return True
return False
def setup_adapter(self):
torch_dtype = get_torch_dtype(self.sd_ref().dtype)
@@ -245,6 +257,13 @@ class CustomAdapter(torch.nn.Module):
elif self.adapter_type == 'redux':
vision_hidden_size = self.vision_encoder.config.hidden_size
self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
elif self.adapter_type == 'mean_flow':
self.mean_flow_adapter = MeanFlowAdapter(
self,
sd=self.sd_ref(),
config=self.config,
train_config=self.train_config
)
elif self.adapter_type == 'control_lora':
self.control_lora = ControlLoraAdapter(
self,
@@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module):
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel", "mean_flow"]:
return
if self.config.type == 'photo_maker':
try:
@@ -528,6 +547,14 @@ class CustomAdapter(torch.nn.Module):
new_dict[k + '.' + k2] = v2
self.control_lora.load_weights(new_dict, strict=strict)
if self.adapter_type == 'mean_flow':
# state dict is separated, so recombine it
new_dict = {}
for k, v in state_dict.items():
for k2, v2 in v.items():
new_dict[k + '.' + k2] = v2
self.mean_flow_adapter.load_weights(new_dict, strict=strict)
if self.adapter_type == 'i2v':
# state dict is separated, so recombine it
new_dict = {}
@@ -599,6 +626,11 @@ class CustomAdapter(torch.nn.Module):
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'mean_flow':
d = self.mean_flow_adapter.get_state_dict()
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'i2v':
d = self.i2v_adapter.get_state_dict()
for k, v in d.items():
@@ -757,7 +789,7 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']:
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v', 'mean_flow']:
return prompt
elif self.adapter_type == 'text_encoder':
# todo allow for training
@@ -1319,6 +1351,10 @@ class CustomAdapter(torch.nn.Module):
param_list = self.control_lora.get_params()
for param in param_list:
yield param
elif self.config.type == 'mean_flow':
param_list = self.mean_flow_adapter.get_params()
for param in param_list:
yield param
elif self.config.type == 'i2v':
param_list = self.i2v_adapter.get_params()
for param in param_list:

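The mean_flow branch in load_weights above reuses the same recombination convention as the control_lora and i2v branches: the checkpoint arrives as nested sub-dicts keyed by a top-level prefix, and the loader flattens them back into dotted keys. A minimal sketch with made-up keys:

import torch

# hypothetical nested checkpoint, split the way the load_weights comment describes
state_dict = {
    "transformer": {
        "mean_flow_timestep_embedder.weight": torch.zeros(4, 8),
        "mean_flow_timestep_embedder.bias": torch.zeros(4),
    },
}

# recombine into flat dotted keys, mirroring the loop in load_weights
new_dict = {}
for k, v in state_dict.items():
    for k2, v2 in v.items():
        new_dict[k + "." + k2] = v2

assert "transformer.mean_flow_timestep_embedder.weight" in new_dict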
View File

@@ -135,7 +135,7 @@ class ControlLoraAdapter(torch.nn.Module):
network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
if hasattr(sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
network_kwargs['target_lin_modules'] = sd.target_lora_modules
if 'ignore_if_contains' not in network_kwargs:
network_kwargs['ignore_if_contains'] = []

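The one-line change above swaps self.sd for the local sd argument. Assuming ControlLoraAdapter follows the same convention as MeanFlowAdapter below and keeps only a weak reference (self.sd_ref) to the model, self.sd would raise AttributeError; a minimal sketch of that pattern:

import weakref

class ToyAdapter:
    def __init__(self, sd):
        # only a weak reference is kept; there is no self.sd attribute
        self.sd_ref = weakref.ref(sd)

    def create_network(self, sd):
        # correct: use the local argument (or self.sd_ref()), not self.sd
        return getattr(sd, "target_lora_modules", [])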
View File

@@ -176,60 +176,3 @@ def add_model_gpu_splitter_to_flux(
transformer._pre_gpu_split_to = transformer.to
transformer.to = partial(new_device_to, transformer)
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
timesteps_emb = timesteps_emb_start + timesteps_emb_end
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)

View File

@@ -0,0 +1,282 @@
import inspect
import weakref
import torch
from typing import TYPE_CHECKING
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import (
CombinedTimestepTextProjEmbeddings,
CombinedTimestepGuidanceTextProjEmbeddings,
)
from functools import partial
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
from toolkit.custom_adapter import CustomAdapter
def mean_flow_time_text_embed_forward(
self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection
):
mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
# append a zero end timestep when only start timesteps are passed
if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat(
[timestep, torch.zeros_like(timestep)], dim=0
) # pairing with an end timestep of 0 (the final timestep) is equivalent to the original single timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(
timesteps_proj.to(dtype=pooled_projection.dtype)
) # (N, D)
# mean flow: fuse the start and end timestep embeddings through the learned projection
if mean_flow_adapter.is_active:
# TODO: make sure the timesteps are batched correctly; diffusers may expect non-batched timesteps
orig_dtype = timesteps_emb.dtype
timesteps_emb = timesteps_emb.to(torch.float32)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
)
timesteps_emb = timesteps_emb.to(orig_dtype)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(
self: CombinedTimestepGuidanceTextProjEmbeddings,
timestep,
guidance,
pooled_projection,
):
mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
# append a zero end timestep when only start timesteps are passed
if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat(
[timestep, torch.zeros_like(timestep)], dim=0
) # pairing with an end timestep of 0 (the final timestep) is equivalent to the original single timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(
timesteps_proj.to(dtype=pooled_projection.dtype)
) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(
guidance_proj.to(dtype=pooled_projection.dtype)
) # (N, D)
# mean flow: fuse the start and end timestep embeddings through the learned projection
if mean_flow_adapter.is_active:
# TODO: make sure the timesteps are batched correctly; diffusers may expect non-batched timesteps
orig_dtype = timesteps_emb.dtype
timesteps_emb = timesteps_emb.to(torch.float32)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
)
timesteps_emb = timesteps_emb.to(orig_dtype)
time_guidance_emb = timesteps_emb + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(
transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings
):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)
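convert_flux_to_mean_flow patches the embedding module by assigning a partial that carries the module instance explicitly; a plain function assigned to an instance attribute would not receive self. A standalone illustration of the pattern with a toy class (not the diffusers module):

from functools import partial

class ToyBlock:
    def __init__(self):
        self.scale = 2.0

def patched_forward(self, x):
    # self arrives because partial() bound the instance explicitly
    return x * self.scale

block = ToyBlock()
block.forward = partial(patched_forward, block)
assert block.forward(3.0) == 6.0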
class MeanFlowAdapter(torch.nn.Module):
def __init__(
self,
adapter: "CustomAdapter",
sd: "StableDiffusion",
config: "AdapterConfig",
train_config: "TrainConfig",
):
super().__init__()
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.sd_ref = weakref.ref(sd)
self.model_config: ModelConfig = sd.model_config
self.network_config = config.lora_config
self.train_config = train_config
self.device_torch = sd.device_torch
self.lora = None
if self.network_config is not None:
network_kwargs = (
{}
if self.network_config.network_kwargs is None
else self.network_config.network_kwargs
)
if hasattr(sd, "target_lora_modules"):
network_kwargs["target_lin_modules"] = sd.target_lora_modules
if "ignore_if_contains" not in network_kwargs:
network_kwargs["ignore_if_contains"] = []
self.lora = LoRASpecialNetwork(
text_encoder=sd.text_encoder,
unet=sd.unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_v3=self.model_config.is_v3,
is_pixart=self.model_config.is_pixart,
is_auraflow=self.model_config.is_auraflow,
is_flux=self.model_config.is_flux,
is_lumina2=self.model_config.is_lumina2,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,
use_bias=False,
is_lorm=False,
network_config=self.network_config,
network_type=self.network_config.type,
transformer_only=self.network_config.transformer_only,
is_transformer=sd.is_transformer,
base_model=sd,
**network_kwargs,
)
self.lora.force_to(self.device_torch, dtype=torch.float32)
self.lora._update_torch_multiplier()
self.lora.apply_to(
sd.text_encoder,
sd.unet,
self.train_config.train_text_encoder,
self.train_config.train_unet,
)
self.lora.can_merge_in = False
self.lora.prepare_grad_etc(sd.text_encoder, sd.unet)
if self.train_config.gradient_checkpointing:
self.lora.enable_gradient_checkpointing()
emb_dim = None
if self.model_config.arch in ["flux", "flex2"]:
transformer: FluxTransformer2DModel = sd.unet
emb_dim = (
transformer.config.num_attention_heads
* transformer.config.attention_head_dim
)
convert_flux_to_mean_flow(transformer)
else:
raise ValueError(f"Unsupported architecture: {self.model_config.arch}")
self.mean_flow_timestep_embedder = torch.nn.Linear(
emb_dim * 2,
emb_dim,
)
# initialize so the embedder passes the start-timestep embedding through unchanged (the end half of the weight is zero), keeping the model's behavior identical to before the adapter was added
with torch.no_grad():
self.mean_flow_timestep_embedder.weight.zero_()
self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
self.mean_flow_timestep_embedder.bias.zero_()
self.mean_flow_timestep_embedder.to(self.device_torch)
# attach a weak reference to this adapter so the patched forward can reach it
if self.model_config.arch in ["flux", "flex2"]:
sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
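A quick standalone check (toy dimensions) of the no-op property the initialization above aims for: because the second half of the weight is zero, the embedder returns the start-timestep embedding unchanged regardless of the end-timestep embedding, so the patched model matches the original until training updates these weights.

import torch

emb_dim = 8  # toy size; the real value is num_attention_heads * attention_head_dim
embedder = torch.nn.Linear(emb_dim * 2, emb_dim)
with torch.no_grad():
    embedder.weight.zero_()
    embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
    embedder.bias.zero_()

start = torch.randn(2, emb_dim)
end = torch.randn(2, emb_dim)  # arbitrary end-timestep embedding
out = embedder(torch.cat([start, end], dim=-1))
assert torch.allclose(out, start)  # end half of the weight is zero, so it is ignored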
def get_params(self):
if self.lora is not None:
config = {
"text_encoder_lr": self.train_config.lr,
"unet_lr": self.train_config.lr,
}
sig = inspect.signature(self.lora.prepare_optimizer_params)
if "default_lr" in sig.parameters:
config["default_lr"] = self.train_config.lr
if "learning_rate" in sig.parameters:
config["learning_rate"] = self.train_config.lr
params_net = self.lora.prepare_optimizer_params(**config)
# we want only tensors here
params = []
for p in params_net:
if isinstance(p, dict):
params += p["params"]
elif isinstance(p, torch.Tensor):
params.append(p)
elif isinstance(p, list):
params += p
else:
params = []
# make sure the embedder is float32
self.mean_flow_timestep_embedder.to(torch.float32)
self.mean_flow_timestep_embedder.requires_grad_(True)  # Module.requires_grad_ so the flag reaches the parameters
self.mean_flow_timestep_embedder.train()
params += list(self.mean_flow_timestep_embedder.parameters())
# return a flat list so callers can simply do: yield from params
return params
def load_weights(self, state_dict, strict=True):
lora_sd = {}
mean_flow_embedder_sd = {}
for key, value in state_dict.items():
if "mean_flow_timestep_embedder" in key:
new_key = key.replace("transformer.mean_flow_timestep_embedder.", "")
mean_flow_embedder_sd[new_key] = value
else:
lora_sd[key] = value
# todo process state dict before loading for models that need it
if self.lora is not None:
self.lora.load_weights(lora_sd)
self.mean_flow_timestep_embedder.load_state_dict(
mean_flow_embedder_sd, strict=False
)
def get_state_dict(self):
if self.lora is not None:
lora_sd = self.lora.get_state_dict(dtype=torch.float32)
else:
lora_sd = {}
# TODO: make sure this matches how LoRAs are keyed elsewhere.
mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict()
for key, value in mean_flow_embedder_sd.items():
lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value
return lora_sd
@property
def is_active(self):
return self.adapter_ref().is_active
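For context on why the embedders are patched to take a (start, end) timestep pair: in the MeanFlow formulation the network predicts an average velocity u(z_t, r, t) over an interval rather than the instantaneous velocity, and the training target comes from the MeanFlow identity. The sketch below illustrates that target computation under those assumptions (toy stand-in model, generic shapes); it is not this repository's trainer code, and the mapping of this repo's start/end timesteps onto (t, r) is assumed.

import torch
from torch.func import jvp

def mean_flow_loss(model, x0, x1, t, r):
    # x0: data latents, x1: noise; t, r are broadcastable times with r <= t
    z_t = (1 - t) * x0 + t * x1   # point on the linear flow path at time t
    v_t = x1 - x0                 # instantaneous velocity along that path

    # u(z, r, t): the model's average velocity over [r, t]
    u = lambda z, r_, t_: model(z, r_, t_)

    # total derivative du/dt along the trajectory via a JVP with tangents (v_t, 0, 1)
    u_pred, du_dt = jvp(u, (z_t, r, t),
                        (v_t, torch.zeros_like(r), torch.ones_like(t)))

    u_target = v_t - (t - r) * du_dt   # MeanFlow identity; stop-gradient on the target
    return torch.mean((u_pred - u_target.detach()) ** 2)

# toy check with a parameter-free stand-in model (a real trainer also has to
# ensure gradients flow through u_pred but not through the target)
model = lambda z, r_, t_: z * (t_ - r_)
x0, x1 = torch.randn(2, 4), torch.randn(2, 4)
t, r = torch.rand(2, 1), torch.zeros(2, 1)
loss = mean_flow_loss(model, x0, x1, t, r)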