Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
More work on mean flow loss. Moved it to an adapter. Still not functioning properly though.
@@ -36,7 +36,6 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.models.flux import convert_flux_to_mean_flow


def flush():
@@ -135,9 +134,6 @@ class SDTrainer(BaseSDTrainProcess):

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        if self.train_config.loss_type == "mean_flow":
            # todo handle non flux models
            convert_flux_to_mean_flow(self.sd.unet)

        if self.train_config.do_prior_divergence:
            self.do_prior_prediction = True
@@ -634,102 +630,6 @@ class SDTrainer(BaseSDTrainProcess):
        return loss

    # ------------------------------------------------------------------
    # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
    # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper)
    # This version avoids jvp / double-back-prop issues with Flash-Attention
    # adapted from the work of lodestonerock
    # ------------------------------------------------------------------
    def get_mean_flow_loss_wip(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        unconditional_embeds: Optional[PromptEmbeds] = None,
        **kwargs
    ):
        batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))

        time_end = timesteps.float() / 1000
        # for timestep_r we need values sampled randomly between 0.0 and time_end
        time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
        # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)

        # Compute noised data points
        # lerp_vector = noisy_latents
        # compute instantaneous vector
        instantaneous_vector = noise - batch_latents

        # finite difference method
        epsilon_fd = 1e-3
        jitter_std = 1e-4
        epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
        epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)

        # f(x) for the primal (we backprop through here)
        # mean_vec_val_pred = self.forward(lerp_vector, class_label)
        mean_vec_val_pred = self.predict_noise(
            noisy_latents=noisy_latents,
            timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            batch=batch,
            **pred_kwargs
        )

        with torch.no_grad():
            perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
            # intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE!
            perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
            # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
            f_x_plus_eps_v = self.predict_noise(
                noisy_latents=perturbed_lerp_vector,
                timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
                conditional_embeds=conditional_embeds,
                unconditional_embeds=unconditional_embeds,
                batch=batch,
                **pred_kwargs
            )

            # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
            mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
            mean_vec_grad = mean_vec_grad_fd

            # calculate the regression target for the mean vector
            time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
            mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad

        # MSE loss
        loss = torch.nn.functional.mse_loss(
            mean_vec_val_pred.float(),
            mean_vec_target.float(),
            reduction='none'
        )
        with torch.no_grad():
            pure_loss = loss.mean().detach()
        # give pure_loss a grad flag so the caller can backward through it without errors
        pure_loss.requires_grad_(True)

        # normalize the loss of each batch element to 1.0
        # this loss has large swings that can hurt the model; per-element normalization prevents that
        with torch.no_grad():
            loss_mean = loss.mean([1, 2, 3], keepdim=True)
        loss = loss / loss_mean
        loss = loss.mean()

        # backward the normalized loss here; the detached pure loss is only returned for logging
        self.accelerator.backward(loss)

        # return the unnormalized loss for logging
        return pure_loss

    # ------------------------------------------------------------------
    # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
    # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper)
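The removed method regresses the MeanFlow identity u(z_t, r, t) = v_t - (t - r) * d/dt u(z_t, r, t), with the time derivative approximated by a finite difference rather than torch.func.jvp (the comment above notes the jvp / double-backprop issues with Flash-Attention). For reference, a minimal self-contained sketch of that target construction, using a hypothetical velocity model f(z, t, r); names and shapes are illustrative and not part of this repo:

import torch
import torch.nn.functional as F

def mean_flow_fd_loss(f, x0, noise, t, r, eps=1e-3):
    # linear interpolation path and its instantaneous velocity
    z_t = (1.0 - t)[:, None, None, None] * x0 + t[:, None, None, None] * noise
    v = noise - x0
    # primal prediction u(z_t, r, t); gradients flow through this call only
    u_pred = f(z_t, t, r)
    with torch.no_grad():
        # finite-difference stand-in for the jvp of u along the direction (v, 1, 0)
        u_shift = f(z_t + eps * v, torch.clamp(t + eps, max=1.0), r)
        du_dt = (u_shift - u_pred) / eps
        # MeanFlow target with stop-gradient: u_tgt = v - (t - r) * du/dt
        u_tgt = v - (t - r)[:, None, None, None] * du_dt
    return F.mse_loss(u_pred.float(), u_tgt.float())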
@@ -811,7 +711,6 @@ class SDTrainer(BaseSDTrainProcess):
            base_eps,
            base_eps + jitter
        )
        # eps = (t_frac - r_frac) / 2

        # eps = 1e-3
        # primary prediction (needs grad)
@@ -575,10 +575,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
            direct_save = False
            if self.adapter_config.train_only_image_encoder:
                direct_save = True
            if self.adapter_config.type == 'redux':
                direct_save = True
            if self.adapter_config.type in ['control_lora', 'subpixel', 'i2v']:
                direct_save = True
            elif isinstance(self.adapter, CustomAdapter):
                direct_save = self.adapter.do_direct_save
            save_ip_adapter_from_diffusers(
                state_dict,
                output_file=file_path,
@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.mean_flow_adapter import MeanFlowAdapter
from toolkit.models.i2v_adapter import I2VAdapter
from toolkit.models.subpixel_adapter import SubpixelAdapter
from toolkit.models.ilora import InstantLoRAModule
@@ -98,6 +99,7 @@ class CustomAdapter(torch.nn.Module):
        self.single_value_adapter: SingleValueAdapter = None
        self.redux_adapter: ReduxImageEncoder = None
        self.control_lora: ControlLoraAdapter = None
        self.mean_flow_adapter: MeanFlowAdapter = None
        self.subpixel_adapter: SubpixelAdapter = None
        self.i2v_adapter: I2VAdapter = None
@@ -125,6 +127,16 @@ class CustomAdapter(torch.nn.Module):
                dtype=self.sd_ref().dtype,
            )
            self.load_state_dict(loaded_state_dict, strict=False)

    @property
    def do_direct_save(self):
        # some adapters save their weights directly, others like ip adapters split the state dict
        if self.config.train_only_image_encoder:
            return True
        if self.config.type in ['control_lora', 'subpixel', 'i2v', 'redux', 'mean_flow']:
            return True
        return False

    def setup_adapter(self):
        torch_dtype = get_torch_dtype(self.sd_ref().dtype)
@@ -245,6 +257,13 @@ class CustomAdapter(torch.nn.Module):
        elif self.adapter_type == 'redux':
            vision_hidden_size = self.vision_encoder.config.hidden_size
            self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
        elif self.adapter_type == 'mean_flow':
            self.mean_flow_adapter = MeanFlowAdapter(
                self,
                sd=self.sd_ref(),
                config=self.config,
                train_config=self.train_config
            )
        elif self.adapter_type == 'control_lora':
            self.control_lora = ControlLoraAdapter(
                self,
@@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module):
    def setup_clip(self):
        adapter_config = self.config
        sd = self.sd_ref()
        if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
        if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel", "mean_flow"]:
            return
        if self.config.type == 'photo_maker':
            try:
@@ -528,6 +547,14 @@ class CustomAdapter(torch.nn.Module):
                    new_dict[k + '.' + k2] = v2
            self.control_lora.load_weights(new_dict, strict=strict)

        if self.adapter_type == 'mean_flow':
            # state dict is separated, so recombine it
            new_dict = {}
            for k, v in state_dict.items():
                for k2, v2 in v.items():
                    new_dict[k + '.' + k2] = v2
            self.mean_flow_adapter.load_weights(new_dict, strict=strict)

        if self.adapter_type == 'i2v':
            # state dict is separated, so recombine it
            new_dict = {}
@@ -599,6 +626,11 @@ class CustomAdapter(torch.nn.Module):
            for k, v in d.items():
                state_dict[k] = v
            return state_dict
        elif self.adapter_type == 'mean_flow':
            d = self.mean_flow_adapter.get_state_dict()
            for k, v in d.items():
                state_dict[k] = v
            return state_dict
        elif self.adapter_type == 'i2v':
            d = self.i2v_adapter.get_state_dict()
            for k, v in d.items():
@@ -757,7 +789,7 @@ class CustomAdapter(torch.nn.Module):
        prompt: Union[List[str], str],
        is_unconditional: bool = False,
    ):
        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']:
        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v', 'mean_flow']:
            return prompt
        elif self.adapter_type == 'text_encoder':
            # todo allow for training
@@ -1319,6 +1351,10 @@ class CustomAdapter(torch.nn.Module):
            param_list = self.control_lora.get_params()
            for param in param_list:
                yield param
        elif self.config.type == 'mean_flow':
            param_list = self.mean_flow_adapter.get_params()
            for param in param_list:
                yield param
        elif self.config.type == 'i2v':
            param_list = self.i2v_adapter.get_params()
            for param in param_list:
@@ -135,7 +135,7 @@ class ControlLoraAdapter(torch.nn.Module):

        network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
        if hasattr(sd, 'target_lora_modules'):
            network_kwargs['target_lin_modules'] = self.sd.target_lora_modules
            network_kwargs['target_lin_modules'] = sd.target_lora_modules

        if 'ignore_if_contains' not in network_kwargs:
            network_kwargs['ignore_if_contains'] = []
@@ -176,60 +176,3 @@ def add_model_gpu_splitter_to_flux(
    transformer._pre_gpu_split_to = transformer.to
    transformer.to = partial(new_device_to, transformer)


def mean_flow_time_text_embed_forward(self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
    # make zero timestep ending if none is passed
    if timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)  # timestep - 0 (final timestep) == same as start timestep

    timesteps_proj = self.time_proj(timestep)
    timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

    timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)

    timesteps_emb = timesteps_emb_start + timesteps_emb_end

    pooled_projections = self.text_embedder(pooled_projection)

    conditioning = timesteps_emb + pooled_projections

    return conditioning


def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
    # make zero timestep ending if none is passed
    if timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)  # timestep - 0 (final timestep) == same as start timestep
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

    guidance_proj = self.time_proj(guidance)
    guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)

    timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)

    time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb

    pooled_projections = self.text_embedder(pooled_projection)
    conditioning = time_guidance_emb + pooled_projections

    return conditioning


def convert_flux_to_mean_flow(
    transformer: FluxTransformer2DModel,
):
    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_embed_forward, transformer.time_text_embed
        )
    elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
        )
    else:
        raise ValueError(
            "Unsupported time_text_embed type: {}".format(
                type(transformer.time_text_embed)
            )
        )
toolkit/models/mean_flow_adapter.py (new file, 282 lines)
@@ -0,0 +1,282 @@
import inspect
import weakref
import torch
from typing import TYPE_CHECKING
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import (
    CombinedTimestepTextProjEmbeddings,
    CombinedTimestepGuidanceTextProjEmbeddings,
)
from functools import partial


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter

def mean_flow_time_text_embed_forward(
    self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # make zero timestep ending if none is passed
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat(
            [timestep, torch.zeros_like(timestep)], dim=0
        )  # timestep - 0 (final timestep) == same as start timestep

    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow stuff
    if mean_flow_adapter.is_active:
        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    pooled_projections = self.text_embedder(pooled_projection)

    conditioning = timesteps_emb + pooled_projections

    return conditioning

def mean_flow_time_text_guidance_embed_forward(
    self: CombinedTimestepGuidanceTextProjEmbeddings,
    timestep,
    guidance,
    pooled_projection,
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # make zero timestep ending if none is passed
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat(
            [timestep, torch.zeros_like(timestep)], dim=0
        )  # timestep - 0 (final timestep) == same as start timestep
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    guidance_proj = self.time_proj(guidance)
    guidance_emb = self.guidance_embedder(
        guidance_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow stuff
    if mean_flow_adapter.is_active:
        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    time_guidance_emb = timesteps_emb + guidance_emb

    pooled_projections = self.text_embedder(pooled_projection)
    conditioning = time_guidance_emb + pooled_projections

    return conditioning

def convert_flux_to_mean_flow(
    transformer: FluxTransformer2DModel,
):
    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_embed_forward, transformer.time_text_embed
        )
    elif isinstance(
        transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings
    ):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
        )
    else:
        raise ValueError(
            "Unsupported time_text_embed type: {}".format(
                type(transformer.time_text_embed)
            )
        )

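# note: the patched forwards above call self.mean_flow_adapter_ref(), a weakref that
# MeanFlowAdapter.__init__ below attaches to the embedding module after patching it;
# running the patched forward without that reference in place would raise AttributeError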
class MeanFlowAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: "CustomAdapter",
        sd: "StableDiffusion",
        config: "AdapterConfig",
        train_config: "TrainConfig",
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.device_torch = sd.device_torch
        self.lora = None

        if self.network_config is not None:
            network_kwargs = (
                {}
                if self.network_config.network_kwargs is None
                else self.network_config.network_kwargs
            )
            if hasattr(sd, "target_lora_modules"):
                network_kwargs["target_lin_modules"] = sd.target_lora_modules

            if "ignore_if_contains" not in network_kwargs:
                network_kwargs["ignore_if_contains"] = []

            self.lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs,
            )
            self.lora.force_to(self.device_torch, dtype=torch.float32)
            self.lora._update_torch_multiplier()
            self.lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet,
            )
            self.lora.can_merge_in = False
            self.lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.lora.enable_gradient_checkpointing()

        emb_dim = None
        if self.model_config.arch in ["flux", "flex2", "flex2"]:
            transformer: FluxTransformer2DModel = sd.unet
            emb_dim = (
                transformer.config.num_attention_heads
                * transformer.config.attention_head_dim
            )
            convert_flux_to_mean_flow(transformer)
        else:
            raise ValueError(f"Unsupported architecture: {self.model_config.arch}")

        self.mean_flow_timestep_embedder = torch.nn.Linear(
            emb_dim * 2,
            emb_dim,
        )

        # make the model function as before adding this adapter by initializing the weights
        with torch.no_grad():
            self.mean_flow_timestep_embedder.weight.zero_()
            self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
            self.mean_flow_timestep_embedder.bias.zero_()
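        # with this init the combined embedder is [I | 0] with zero bias, so it returns
        # exactly the start-timestep embedding and an untrained adapter leaves the base
        # model's conditioning unchanged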

        self.mean_flow_timestep_embedder.to(self.device_torch)

        # add our adapter as a weak ref
        if self.model_config.arch in ["flux", "flex2", "flex2"]:
            sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)

    def get_params(self):
        if self.lora is not None:
            config = {
                "text_encoder_lr": self.train_config.lr,
                "unet_lr": self.train_config.lr,
            }
            sig = inspect.signature(self.lora.prepare_optimizer_params)
            if "default_lr" in sig.parameters:
                config["default_lr"] = self.train_config.lr
            if "learning_rate" in sig.parameters:
                config["learning_rate"] = self.train_config.lr
            params_net = self.lora.prepare_optimizer_params(**config)

            # we want only tensors here
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        # make sure the embedder is float32
        self.mean_flow_timestep_embedder.to(torch.float32)
        self.mean_flow_timestep_embedder.requires_grad = True
        self.mean_flow_timestep_embedder.train()

        params += list(self.mean_flow_timestep_embedder.parameters())

        # we need to be able to yield from the list like yield from params
        return params

    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        mean_flow_embedder_sd = {}
        for key, value in state_dict.items():
            if "mean_flow_timestep_embedder" in key:
                new_key = key.replace("transformer.mean_flow_timestep_embedder.", "")
                mean_flow_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process state dict before loading for models that need it
        if self.lora is not None:
            self.lora.load_weights(lora_sd)
        self.mean_flow_timestep_embedder.load_state_dict(
            mean_flow_embedder_sd, strict=False
        )

    def get_state_dict(self):
        if self.lora is not None:
            lora_sd = self.lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}
        # todo make sure we match loras elsewhere
        mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict()
        for key, value in mean_flow_embedder_sd.items():
            lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value
        return lora_sd

    @property
    def is_active(self):
        return self.adapter_ref().is_active
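Taken together, the adapter assumes the trainer concatenates the current and origin timesteps along the batch dimension before calling the model; the patched embedder then chunks them apart and fuses the two embeddings through mean_flow_timestep_embedder. A hedged sketch of that calling convention, mirroring the trainer code earlier in this diff (tensor names are illustrative):

    t = timesteps.float() / 1000                    # current ("end") time in [0, 1]
    r = torch.rand_like(t) * t                      # earlier ("origin") time in [0, t)
    model_pred = self.predict_noise(
        noisy_latents=noisy_latents,
        timesteps=torch.cat([t, r], dim=0) * 1000,  # first half is t, second half is r
        conditional_embeds=conditional_embeds,
        batch=batch,
        **pred_kwargs
    )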