From 1c2b7298dd29a70294ca28b0eda4c241af6b0ad9 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 16 Jun 2025 07:17:35 -0600 Subject: [PATCH] More work on mean flow loss. Moved it to an adapter. Still not functioning properly though. --- extensions_built_in/sd_trainer/SDTrainer.py | 101 ------- jobs/process/BaseSDTrainProcess.py | 6 +- toolkit/custom_adapter.py | 40 ++- toolkit/models/control_lora_adapter.py | 2 +- toolkit/models/flux.py | 57 ---- toolkit/models/mean_flow_adapter.py | 282 ++++++++++++++++++++ 6 files changed, 323 insertions(+), 165 deletions(-) create mode 100644 toolkit/models/mean_flow_adapter.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index f3813133..a0de8cd1 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -36,7 +36,6 @@ from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe from toolkit.util.wavelet_loss import wavelet_loss import torch.nn.functional as F -from toolkit.models.flux import convert_flux_to_mean_flow def flush(): @@ -135,9 +134,6 @@ class SDTrainer(BaseSDTrainProcess): def hook_before_train_loop(self): super().hook_before_train_loop() - if self.train_config.loss_type == "mean_flow": - # todo handle non flux models - convert_flux_to_mean_flow(self.sd.unet) if self.train_config.do_prior_divergence: self.do_prior_prediction = True @@ -634,102 +630,6 @@ class SDTrainer(BaseSDTrainProcess): return loss - # ------------------------------------------------------------------ - # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative - # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper) - # This version avoids jvp / double-back-prop issues with Flash-Attention - # adapted from the work of lodestonerock - # ------------------------------------------------------------------ - def get_mean_flow_loss_wip( - self, - noisy_latents: torch.Tensor, - conditional_embeds: PromptEmbeds, - match_adapter_assist: bool, - network_weight_list: list, - timesteps: torch.Tensor, - pred_kwargs: dict, - batch: 'DataLoaderBatchDTO', - noise: torch.Tensor, - unconditional_embeds: Optional[PromptEmbeds] = None, - **kwargs - ): - batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) - - - time_end = timesteps.float() / 1000 - # for timestep_r, we need values from timestep_end to 0.0 randomly - time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end - - # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype) - # Compute noised data points - # lerp_vector = noisy_latents - # compute instantaneous vector - instantaneous_vector = noise - batch_latents - - # finite difference method - epsilon_fd = 1e-3 - jitter_std = 1e-4 - epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std - epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4) - - # f(x + epsilon * v) for the primal (we backprop through here) - # mean_vec_val_pred = self.forward(lerp_vector, class_label) - mean_vec_val_pred = self.predict_noise( - noisy_latents=noisy_latents, - timesteps=torch.cat([time_end, time_origin], dim=0) * 1000, - conditional_embeds=conditional_embeds, - unconditional_embeds=unconditional_embeds, - batch=batch, - **pred_kwargs - ) - - with torch.no_grad(): - perturbed_time_end = torch.clamp(time_end + 
epsilon_jittered, 0.0, 1.0) - # intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE! - perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector - # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label) - f_x_plus_eps_v = self.predict_noise( - noisy_latents=perturbed_lerp_vector, - timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000, - conditional_embeds=conditional_embeds, - unconditional_embeds=unconditional_embeds, - batch=batch, - **pred_kwargs - ) - - # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon - mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered - mean_vec_grad = mean_vec_grad_fd - - - # calculate the regression target the mean vector - time_difference_broadcast = (time_end - time_origin)[:, None, None, None] - mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad - - # 5) MSE loss - loss = torch.nn.functional.mse_loss( - mean_vec_val_pred.float(), - mean_vec_target.float(), - reduction='none' - ) - with torch.no_grad(): - pure_loss = loss.mean().detach() - # add grad to pure_loss so it can be backwards without issues - pure_loss.requires_grad_(True) - # normalize the loss per batch element to 1.0 - # this method has large loss swings that can hurt the model. This method will prevent that - with torch.no_grad(): - loss_mean = loss.mean([1, 2, 3], keepdim=True) - loss = loss / loss_mean - loss = loss.mean() - - # backward the pure loss for logging - self.accelerator.backward(loss) - - # return the real loss for logging - return pure_loss - - # ------------------------------------------------------------------ # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative # Modelling”, 2025 – see Alg. 1 + Eq. 
(6) of the paper) @@ -811,7 +711,6 @@ class SDTrainer(BaseSDTrainProcess): base_eps, base_eps + jitter ) - # eps = (t_frac - r_frac) / 2 # eps = 1e-3 # primary prediction (needs grad) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 064e89fc..1e0393c0 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -575,10 +575,8 @@ class BaseSDTrainProcess(BaseTrainProcess): direct_save = False if self.adapter_config.train_only_image_encoder: direct_save = True - if self.adapter_config.type == 'redux': - direct_save = True - if self.adapter_config.type in ['control_lora', 'subpixel', 'i2v']: - direct_save = True + elif isinstance(self.adapter, CustomAdapter): + direct_save = self.adapter.do_direct_save save_ip_adapter_from_diffusers( state_dict, output_file=file_path, diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index a9ec5d10..cc58c5b5 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.models.clip_fusion import CLIPFusionModule from toolkit.models.clip_pre_processor import CLIPImagePreProcessor from toolkit.models.control_lora_adapter import ControlLoraAdapter +from toolkit.models.mean_flow_adapter import MeanFlowAdapter from toolkit.models.i2v_adapter import I2VAdapter from toolkit.models.subpixel_adapter import SubpixelAdapter from toolkit.models.ilora import InstantLoRAModule @@ -98,6 +99,7 @@ class CustomAdapter(torch.nn.Module): self.single_value_adapter: SingleValueAdapter = None self.redux_adapter: ReduxImageEncoder = None self.control_lora: ControlLoraAdapter = None + self.mean_flow_adapter: MeanFlowAdapter = None self.subpixel_adapter: SubpixelAdapter = None self.i2v_adapter: I2VAdapter = None @@ -125,6 +127,16 @@ class CustomAdapter(torch.nn.Module): dtype=self.sd_ref().dtype, ) self.load_state_dict(loaded_state_dict, strict=False) + + @property + def do_direct_save(self): + # some adapters save their weights directly, others like ip adapters split the state dict + if self.config.train_only_image_encoder: + return True + if self.config.type in ['control_lora', 'subpixel', 'i2v', 'redux', 'mean_flow']: + return True + return False + def setup_adapter(self): torch_dtype = get_torch_dtype(self.sd_ref().dtype) @@ -245,6 +257,13 @@ class CustomAdapter(torch.nn.Module): elif self.adapter_type == 'redux': vision_hidden_size = self.vision_encoder.config.hidden_size self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype) + elif self.adapter_type == 'mean_flow': + self.mean_flow_adapter = MeanFlowAdapter( + self, + sd=self.sd_ref(), + config=self.config, + train_config=self.train_config + ) elif self.adapter_type == 'control_lora': self.control_lora = ControlLoraAdapter( self, @@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module): def setup_clip(self): adapter_config = self.config sd = self.sd_ref() - if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]: + if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel", "mean_flow"]: return if self.config.type == 'photo_maker': try: @@ -528,6 +547,14 @@ class CustomAdapter(torch.nn.Module): new_dict[k + '.' + k2] = v2 self.control_lora.load_weights(new_dict, strict=strict) + if self.adapter_type == 'mean_flow': + # state dict is seperated. 
so recombine it + new_dict = {} + for k, v in state_dict.items(): + for k2, v2 in v.items(): + new_dict[k + '.' + k2] = v2 + self.mean_flow_adapter.load_weights(new_dict, strict=strict) + if self.adapter_type == 'i2v': # state dict is seperated. so recombine it new_dict = {} @@ -599,6 +626,11 @@ class CustomAdapter(torch.nn.Module): for k, v in d.items(): state_dict[k] = v return state_dict + elif self.adapter_type == 'mean_flow': + d = self.mean_flow_adapter.get_state_dict() + for k, v in d.items(): + state_dict[k] = v + return state_dict elif self.adapter_type == 'i2v': d = self.i2v_adapter.get_state_dict() for k, v in d.items(): @@ -757,7 +789,7 @@ class CustomAdapter(torch.nn.Module): prompt: Union[List[str], str], is_unconditional: bool = False, ): - if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']: + if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v', 'mean_flow']: return prompt elif self.adapter_type == 'text_encoder': # todo allow for training @@ -1319,6 +1351,10 @@ class CustomAdapter(torch.nn.Module): param_list = self.control_lora.get_params() for param in param_list: yield param + elif self.config.type == 'mean_flow': + param_list = self.mean_flow_adapter.get_params() + for param in param_list: + yield param elif self.config.type == 'i2v': param_list = self.i2v_adapter.get_params() for param in param_list: diff --git a/toolkit/models/control_lora_adapter.py b/toolkit/models/control_lora_adapter.py index 3588302d..38147ea9 100644 --- a/toolkit/models/control_lora_adapter.py +++ b/toolkit/models/control_lora_adapter.py @@ -135,7 +135,7 @@ class ControlLoraAdapter(torch.nn.Module): network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs if hasattr(sd, 'target_lora_modules'): - network_kwargs['target_lin_modules'] = self.sd.target_lora_modules + network_kwargs['target_lin_modules'] = sd.target_lora_modules if 'ignore_if_contains' not in network_kwargs: network_kwargs['ignore_if_contains'] = [] diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py index 5d3064f4..0241ce2f 100644 --- a/toolkit/models/flux.py +++ b/toolkit/models/flux.py @@ -176,60 +176,3 @@ def add_model_gpu_splitter_to_flux( transformer._pre_gpu_split_to = transformer.to transformer.to = partial(new_device_to, transformer) - -def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection): - # make zero timestep ending if none is passed - if timestep.shape[0] == pooled_projection.shape[0]: - timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep - - timesteps_proj = self.time_proj(timestep) - timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0) - - timesteps_emb = timesteps_emb_start + timesteps_emb_end - - pooled_projections = self.text_embedder(pooled_projection) - - conditioning = timesteps_emb + pooled_projections - - return conditioning - -def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection): - # make zero timestep ending if none is passed - if timestep.shape[0] == pooled_projection.shape[0]: - timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start 
timestep - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - guidance_proj = self.time_proj(guidance) - guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) - - time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb - - pooled_projections = self.text_embedder(pooled_projection) - conditioning = time_guidance_emb + pooled_projections - - return conditioning - - -def convert_flux_to_mean_flow( - transformer: FluxTransformer2DModel, -): - if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings): - transformer.time_text_embed.forward = partial( - mean_flow_time_text_embed_forward, transformer.time_text_embed - ) - elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings): - transformer.time_text_embed.forward = partial( - mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed - ) - else: - raise ValueError( - "Unsupported time_text_embed type: {}".format( - type(transformer.time_text_embed) - ) - ) - \ No newline at end of file diff --git a/toolkit/models/mean_flow_adapter.py b/toolkit/models/mean_flow_adapter.py new file mode 100644 index 00000000..1c5a6c15 --- /dev/null +++ b/toolkit/models/mean_flow_adapter.py @@ -0,0 +1,282 @@ +import inspect +import weakref +import torch +from typing import TYPE_CHECKING +from toolkit.lora_special import LoRASpecialNetwork +from diffusers import FluxTransformer2DModel +from diffusers.models.embeddings import ( + CombinedTimestepTextProjEmbeddings, + CombinedTimestepGuidanceTextProjEmbeddings, +) +from functools import partial + + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig + from toolkit.custom_adapter import CustomAdapter + + +def mean_flow_time_text_embed_forward( + self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection +): + mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref() + # make zero timestep ending if none is passed + if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]: + timestep = torch.cat( + [timestep, torch.zeros_like(timestep)], dim=0 + ) # timestep - 0 (final timestep) == same as start timestep + + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + # mean flow stuff + if mean_flow_adapter.is_active: + # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps + orig_dtype = timesteps_emb.dtype + timesteps_emb = timesteps_emb.to(torch.float32) + timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) + timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder( + torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1) + ) + timesteps_emb = timesteps_emb.to(orig_dtype) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +def mean_flow_time_text_guidance_embed_forward( + self: CombinedTimestepGuidanceTextProjEmbeddings, + timestep, + guidance, + pooled_projection, +): + mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref() + # make zero timestep ending if none is passed + if mean_flow_adapter.is_active 
and timestep.shape[0] == pooled_projection.shape[0]: + timestep = torch.cat( + [timestep, torch.zeros_like(timestep)], dim=0 + ) # timestep - 0 (final timestep) == same as start timestep + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder( + guidance_proj.to(dtype=pooled_projection.dtype) + ) # (N, D) + + # mean flow stuff + if mean_flow_adapter.is_active: + # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps + orig_dtype = timesteps_emb.dtype + timesteps_emb = timesteps_emb.to(torch.float32) + timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) + timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder( + torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1) + ) + timesteps_emb = timesteps_emb.to(orig_dtype) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +def convert_flux_to_mean_flow( + transformer: FluxTransformer2DModel, +): + if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_embed_forward, transformer.time_text_embed + ) + elif isinstance( + transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings + ): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed + ) + else: + raise ValueError( + "Unsupported time_text_embed type: {}".format( + type(transformer.time_text_embed) + ) + ) + + +class MeanFlowAdapter(torch.nn.Module): + def __init__( + self, + adapter: "CustomAdapter", + sd: "StableDiffusion", + config: "AdapterConfig", + train_config: "TrainConfig", + ): + super().__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref = weakref.ref(sd) + self.model_config: ModelConfig = sd.model_config + self.network_config = config.lora_config + self.train_config = train_config + self.device_torch = sd.device_torch + self.lora = None + + if self.network_config is not None: + network_kwargs = ( + {} + if self.network_config.network_kwargs is None + else self.network_config.network_kwargs + ) + if hasattr(sd, "target_lora_modules"): + network_kwargs["target_lin_modules"] = sd.target_lora_modules + + if "ignore_if_contains" not in network_kwargs: + network_kwargs["ignore_if_contains"] = [] + + self.lora = LoRASpecialNetwork( + text_encoder=sd.text_encoder, + unet=sd.unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + 
use_bias=False,
+                is_lorm=False,
+                network_config=self.network_config,
+                network_type=self.network_config.type,
+                transformer_only=self.network_config.transformer_only,
+                is_transformer=sd.is_transformer,
+                base_model=sd,
+                **network_kwargs,
+            )
+            self.lora.force_to(self.device_torch, dtype=torch.float32)
+            self.lora._update_torch_multiplier()
+            self.lora.apply_to(
+                sd.text_encoder,
+                sd.unet,
+                self.train_config.train_text_encoder,
+                self.train_config.train_unet,
+            )
+            self.lora.can_merge_in = False
+            self.lora.prepare_grad_etc(sd.text_encoder, sd.unet)
+            if self.train_config.gradient_checkpointing:
+                self.lora.enable_gradient_checkpointing()
+
+        emb_dim = None
+        if self.model_config.arch in ["flux", "flex2"]:
+            transformer: FluxTransformer2DModel = sd.unet
+            emb_dim = (
+                transformer.config.num_attention_heads
+                * transformer.config.attention_head_dim
+            )
+            convert_flux_to_mean_flow(transformer)
+        else:
+            raise ValueError(f"Unsupported architecture: {self.model_config.arch}")
+
+        self.mean_flow_timestep_embedder = torch.nn.Linear(
+            emb_dim * 2,
+            emb_dim,
+        )
+
+        # initialize so the embedder passes the start-timestep embedding through
+        # unchanged, keeping the model's behavior identical to before the adapter was added
+        with torch.no_grad():
+            self.mean_flow_timestep_embedder.weight.zero_()
+            self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
+            self.mean_flow_timestep_embedder.bias.zero_()
+
+        self.mean_flow_timestep_embedder.to(self.device_torch)
+
+        # add our adapter as a weak ref
+        if self.model_config.arch in ["flux", "flex2"]:
+            sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
+
+    def get_params(self):
+        if self.lora is not None:
+            config = {
+                "text_encoder_lr": self.train_config.lr,
+                "unet_lr": self.train_config.lr,
+            }
+            sig = inspect.signature(self.lora.prepare_optimizer_params)
+            if "default_lr" in sig.parameters:
+                config["default_lr"] = self.train_config.lr
+            if "learning_rate" in sig.parameters:
+                config["learning_rate"] = self.train_config.lr
+            params_net = self.lora.prepare_optimizer_params(**config)
+
+            # we want only tensors here
+            params = []
+            for p in params_net:
+                if isinstance(p, dict):
+                    params += p["params"]
+                elif isinstance(p, torch.Tensor):
+                    params.append(p)
+                elif isinstance(p, list):
+                    params += p
+        else:
+            params = []
+
+        # make sure the embedder is float32 and trainable
+        self.mean_flow_timestep_embedder.to(torch.float32)
+        self.mean_flow_timestep_embedder.requires_grad_(True)
+        self.mean_flow_timestep_embedder.train()
+
+        params += list(self.mean_flow_timestep_embedder.parameters())
+
+        # CustomAdapter.get_params() iterates over this, so return a flat list of tensors
+        return params
+
+    def load_weights(self, state_dict, strict=True):
+        lora_sd = {}
+        mean_flow_embedder_sd = {}
+        for key, value in state_dict.items():
+            if "mean_flow_timestep_embedder" in key:
+                new_key = key.replace("transformer.mean_flow_timestep_embedder.", "")
+                mean_flow_embedder_sd[new_key] = value
+            else:
+                lora_sd[key] = value
+
+        # todo process state dict before loading for models that need it
+        if self.lora is not None:
+            self.lora.load_weights(lora_sd)
+        self.mean_flow_timestep_embedder.load_state_dict(
+            mean_flow_embedder_sd, strict=False
+        )
+
+    def get_state_dict(self):
+        if self.lora is not None:
+            lora_sd = self.lora.get_state_dict(dtype=torch.float32)
+        else:
+            lora_sd = {}
+        # todo make sure we match loras elsewhere.
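+        # embedder keys are stored under a "transformer.mean_flow_timestep_embedder."
+        # prefix so load_weights() above can split them back out from the LoRA weights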
+ mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict() + for key, value in mean_flow_embedder_sd.items(): + lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value + return lora_sd + + @property + def is_active(self): + return self.adapter_ref().is_active
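
For readers following the mean-flow work in this patch: below is a minimal, illustrative sketch of the objective the adapter is built for (Geng et al., "Mean Flows for One-step Generative Modeling", 2025, Alg. 1 / Eq. 6), using the same finite-difference JVP approximation as the removed SDTrainer implementation. The standalone `mean_flow_loss_sketch` helper and the generic `model(z, r, t)` signature are assumptions made for clarity only; the actual trainer routes predictions through `predict_noise()` with stacked `[t, r]` timesteps and the MeanFlowAdapter's modified time embedding.

import torch
import torch.nn.functional as F

def mean_flow_loss_sketch(model, x, noise, t, r, eps=1e-3):
    # Flow-matching interpolation z_t = (1 - t) * x + t * noise
    t_b = t[:, None, None, None]
    r_b = r[:, None, None, None]
    z_t = (1.0 - t_b) * x + t_b * noise
    v = noise - x  # instantaneous velocity

    # primal prediction u(z_t, r, t); gradients flow only through this call
    u_pred = model(z_t, r, t)

    with torch.no_grad():
        # finite-difference estimate of the total derivative du/dt:
        # (u(z + eps * v, r, t + eps) - u(z, r, t)) / eps
        t_eps = torch.clamp(t + eps, max=1.0)
        u_shift = model(z_t + eps * v, r, t_eps)
        du_dt = (u_shift - u_pred) / eps
        # MeanFlow regression target: u_tgt = v - (t - r) * du/dt
        u_tgt = v - (t_b - r_b) * du_dt

    return F.mse_loss(u_pred.float(), u_tgt.float())


# toy usage: dummy velocity model, random latents, r sampled in [0, t]
if __name__ == "__main__":
    B, C, H, W = 2, 4, 8, 8
    x, noise = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
    t = torch.rand(B)
    r = torch.rand(B) * t
    dummy_model = lambda z, r_, t_: torch.zeros_like(z)
    print(mean_flow_loss_sketch(dummy_model, x, noise, t, r))

The finite-difference form is used here for the same reason the patch gives for the original implementation: it avoids torch.func.jvp and the double-backprop issues it causes with flash attention.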