diff --git a/extensions_built_in/diffusion_models/omnigen2/__init__.py b/extensions_built_in/diffusion_models/omnigen2/__init__.py
index 5b20b387..ee37a1f0 100644
--- a/extensions_built_in/diffusion_models/omnigen2/__init__.py
+++ b/extensions_built_in/diffusion_models/omnigen2/__init__.py
@@ -239,7 +239,10 @@ class OmniGen2Model(BaseModel):
         **kwargs
     ):
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        try:
+            timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        except Exception as e:
+            pass
 
         # optional_kwargs = {}
         # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
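With the mean-flow adapter active, the trainer passes the t and r timesteps stacked along the batch dimension (see the SDTrainer changes below), so `timestep` can arrive with twice as many entries as `latent_model_input` has batch items; the try/except above appears to exist so that case falls through with the already-batched timesteps instead of raising, while the normal single-timestep case still broadcasts. A minimal sketch of the shapes involved (all values hypothetical):

import torch

latent_model_input = torch.randn(2, 16, 32, 32)          # batch of 2
timestep = torch.tensor([250.0, 600.0, 900.0, 900.0])    # 2 * batch: [t..., r...]

try:
    # expand() cannot shrink a length-4 tensor to length 2, so this raises
    # whenever the doubled [t, r] batch comes through
    timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
except Exception:
    # keep the already-batched (2B,) timesteps unchanged
    pass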
diff --git a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
index 45e509bc..651f1add 100644
--- a/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
+++ b/extensions_built_in/diffusion_models/omnigen2/src/pipelines/omnigen2/pipeline_omnigen2.py
@@ -305,14 +305,14 @@ class OmniGen2Pipeline(DiffusionPipeline):
         )
 
         text_input_ids = text_inputs.input_ids.to(device)
-        untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
+        # untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
 
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because Gemma can only handle sequences up to"
-                f" {max_sequence_length} tokens: {removed_text}"
-            )
+        # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+        #     removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+        #     logger.warning(
+        #         "The following part of your input was truncated because Gemma can only handle sequences up to"
+        #         f" {max_sequence_length} tokens: {removed_text}"
+        #     )
 
         prompt_attention_mask = text_inputs.attention_mask.to(device)
         prompt_embeds = self.mllm(
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index b4439f30..ca808823 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -459,15 +459,33 @@ class SDTrainer(BaseSDTrainProcess):
                 stepped_latents = torch.cat(stepped_chunks, dim=0)
                 stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
 
+                # resize to half the size of the latents
+                stepped_latents_half = torch.nn.functional.interpolate(
+                    stepped_latents,
+                    size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
+                    mode='bilinear',
+                    align_corners=False
+                )
                 pred_features = self.dfe(stepped_latents.float())
+                pred_features_half = self.dfe(stepped_latents_half.float())
                 with torch.no_grad():
                     target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
+                    batch_latents_half = torch.nn.functional.interpolate(
+                        batch.latents.to(self.device_torch, dtype=torch.float32),
+                        size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
+                        mode='bilinear',
+                        align_corners=False
+                    )
+                    target_features_half = self.dfe(batch_latents_half)
 
                 # scale dfe so it is weaker at higher noise levels
                 dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
                 dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
                     self.train_config.diffusion_feature_extractor_weight * dfe_scaler
-                additional_loss += dfe_loss.mean()
+
+                dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
+                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
+                additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
             elif self.dfe.version == 2:
                 # version 2
                 # do diffusion feature extraction on target
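The hunk above adds a second diffusion-feature-extractor term computed on latents bilinearly downsampled to half resolution. A self-contained sketch of that multi-scale loss, assuming `dfe` is a feature extractor returning 4D feature maps; the function name and arguments are hypothetical stand-ins for the trainer's own objects:

import torch
import torch.nn.functional as F

def multiscale_dfe_loss(dfe, pred_latents, target_latents, timesteps, weight):
    # downsample to half resolution with the same settings as the hunk above
    def half(x):
        return F.interpolate(
            x,
            size=(x.shape[2] // 2, x.shape[3] // 2),
            mode='bilinear',
            align_corners=False,
        )

    pred_feats = dfe(pred_latents.float())
    pred_feats_half = dfe(half(pred_latents).float())
    with torch.no_grad():
        target_feats = dfe(target_latents.float())
        target_feats_half = dfe(half(target_latents).float())

    # weaker supervision at higher noise levels, mirroring dfe_scaler
    scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1)

    loss_full = (F.mse_loss(pred_feats, target_feats, reduction="none") * weight * scaler).mean()
    loss_half = (F.mse_loss(pred_feats_half, target_feats_half, reduction="none") * weight * scaler).mean()
    return loss_full + loss_half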
@@ -735,136 +753,106 @@ class SDTrainer(BaseSDTrainProcess):
         unconditional_embeds: Optional[PromptEmbeds] = None,
         **kwargs
     ):
-        # ------------------------------------------------------------------
-        # “Slow” Mean-Flow loss – finite-difference version
-        # (avoids JVP / double-backprop issues with Flash-Attention)
-        # ------------------------------------------------------------------
-        dtype = get_torch_dtype(self.train_config.dtype)
-        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # 1000
-        base_eps = 1e-3  # this is one step when multiplied by 1000
+        dtype = get_torch_dtype(self.train_config.dtype)
+        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # e.g. 1000
+        base_eps = 1e-3
+        min_time_gap = 1e-2
 
         with torch.no_grad():
             num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
             batch_size = batch.latents.shape[0]
             timestep_t_list = []
             timestep_r_list = []
+
             for i in range(batch_size):
                 t1 = random.randint(0, num_train_timesteps - 1)
                 t2 = random.randint(0, num_train_timesteps - 1)
                 t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
                 t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
-                if (t_t - t_r).item() < base_eps * 1000:
-                    # we need to ensure the time gap is wider than the epsilon(one step)
-                    scaled_eps = base_eps * 1000
-                    if t_t.item() + scaled_eps > 1000:
-                        t_r = t_r - scaled_eps
+                if (t_t - t_r).item() < min_time_gap * 1000:
+                    scaled_time_gap = min_time_gap * 1000
+                    if t_t.item() + scaled_time_gap > 1000:
+                        t_r = t_r - scaled_time_gap
                     else:
-                        t_t = t_t + scaled_eps
+                        t_t = t_t + scaled_time_gap
                 timestep_t_list.append(t_t)
                 timestep_r_list.append(t_r)
-                eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps
+
             timesteps_t = torch.stack(timestep_t_list, dim=0).float()
             timesteps_r = torch.stack(timestep_r_list, dim=0).float()
 
-            # fractions in [0,1]
-            t_frac = timesteps_t / total_steps
-            r_frac = timesteps_r / total_steps
+            t_frac = timesteps_t / total_steps  # [0,1]
+            r_frac = timesteps_r / total_steps  # [0,1]
 
-            # 2) construct data points
-            latents_clean = batch.latents.to(dtype)
-            noise_sample = noise.to(dtype)
+            latents_clean = batch.latents.to(dtype)
+            noise_sample = noise.to(dtype)
 
-            lerp_vector = noise_sample * t_frac[:, None, None, None] \
-                + latents_clean * (1.0 - t_frac[:, None, None, None])
+            lerp_vector = latents_clean * (1.0 - t_frac[:, None, None, None]) + noise_sample * t_frac[:, None, None, None]
 
-            if hasattr(self.sd, 'get_loss_target'):
-                instantaneous_vector = self.sd.get_loss_target(
-                    noise=noise_sample,
-                    batch=batch,
-                    timesteps=timesteps,
-                ).detach()
-            else:
-                instantaneous_vector = noise_sample - latents_clean  # v_t (B,C,H,W)
+            eps = base_eps
 
-        # 3) finite-difference JVP approximation (bump z **and** t)
-        # eps_base, eps_jitter = 1e-3, 1e-4
-        # eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
-        jitter = 1e-4
-        eps = value_map(
-            torch.rand_like(t_frac),
-            0.0,
-            1.0,
-            base_eps,
-            base_eps + jitter
-        )
-        # eps = (t_frac - r_frac) / 2
-
-        # eps = 1e-3
-        # primary prediction (needs grad)
-        mean_vec_pred = self.predict_noise(
-            noisy_latents=lerp_vector,
-            timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
+        # concatenate timesteps as input for u(z, r, t)
+        timesteps_cat = torch.cat([t_frac, r_frac], dim=0) * total_steps
+
+        # model predicts u(z, r, t)
+        u_pred = self.predict_noise(
+            noisy_latents=lerp_vector.to(dtype),
+            timesteps=timesteps_cat.to(dtype),
             conditional_embeds=conditional_embeds,
             unconditional_embeds=unconditional_embeds,
             batch=batch,
             **pred_kwargs
         )
 
-        # secondary prediction: bump both latent and timestep by ε
         with torch.no_grad():
-            # lerp_perturbed = lerp_vector + eps * instantaneous_vector
-            t_frac_plus_eps = t_frac + eps  # bump time fraction
-            lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
-                + latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
+            t_frac_plus_eps = (t_frac + eps).clamp(0.0, 1.0)
+            lerp_perturbed = latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None]) + noise_sample * t_frac_plus_eps[:, None, None, None]
+            timesteps_cat_perturbed = torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps
 
-            f_x_plus_eps_v = self.predict_noise(
-                noisy_latents=lerp_perturbed,
-                timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
+            u_perturbed = self.predict_noise(
+                noisy_latents=lerp_perturbed.to(dtype),
+                timesteps=timesteps_cat_perturbed.to(dtype),
                 conditional_embeds=conditional_embeds,
                 unconditional_embeds=unconditional_embeds,
                 batch=batch,
                 **pred_kwargs
            )
 
-        # finite-difference JVP: (f(x+εv) – f(x)) / ε
-        mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
-        # time_gap = (t_frac - r_frac)[:, None, None, None]
-        # mean_vec_scaler = time_gap / eps
-        # mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
-        # mean_vec_grad = mean_vec_grad.detach()  # stop-grad as in Eq. 11
+        # compute du/dt via finite difference (detached)
+        du_dt = (u_perturbed - u_pred).detach() / eps
+        # du_dt = (u_perturbed - u_pred).detach()
+        du_dt = du_dt.to(dtype)
+
+        time_gap = (t_frac - r_frac)[:, None, None, None].to(dtype)
+        time_gap = time_gap.clamp(min=1e-4)
+        u_shifted = u_pred + time_gap * du_dt
+        # u_shifted = u_pred + du_dt / time_gap
+        # u_shifted = u_pred
 
-        # 4) regression target for the mean vector
-        time_gap = (t_frac - r_frac)[:, None, None, None]
-        mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
-        # mean_vec_target = instantaneous_vector - mean_vec_grad
+        # a step is done like this:
+        # stepped_latent = model_input + (timestep_next - timestep) * model_output
+
+        # flow target velocity
+        # v_target = (noise_sample - latents_clean) / time_gap
+        # flux predicts opposite of velocity, so we need to invert it
+        v_target = (latents_clean - noise_sample) / time_gap
 
-        # # 5) MSE loss
-        # loss = torch.nn.functional.mse_loss(
-        #     mean_vec_pred.float(),
-        #     mean_vec_target.float()
-        # )
-        # return loss
-        # 5) MSE loss
+        # compute loss
         loss = torch.nn.functional.mse_loss(
-            mean_vec_pred.float(),
-            mean_vec_target.float(),
+            u_shifted.float(),
+            v_target.float(),
             reduction='none'
         )
+
         with torch.no_grad():
             pure_loss = loss.mean().detach()
-            # add grad to pure_loss so it can be backwards without issues
             pure_loss.requires_grad_(True)
 
-        # normalize the loss per batch element to 1.0
-        # this method has large loss swings that can hurt the model. This method will prevent that
-        # with torch.no_grad():
-        #     loss_mean = loss.mean([1, 2, 3], keepdim=True)
-        #     loss = loss / loss_mean
+
         loss = loss.mean()
-
-        # backward the pure loss for logging
+        if loss.item() > 1e3:
+            pass
         self.accelerator.backward(loss)
-
-        # return the real loss for logging
         return pure_loss
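Stripped of the trainer plumbing, the objective above pairs a finite-difference estimate of du/dt with a velocity-style target. A minimal sketch that mirrors the new code path, where `model(z, r, t)` is a hypothetical stand-in for `predict_noise` returning u(z, r, t):

import torch
import torch.nn.functional as F

def mean_flow_fd_loss(model, latents_clean, noise, t_frac, r_frac, eps=1e-3):
    tb = t_frac[:, None, None, None]
    z_t = latents_clean * (1.0 - tb) + noise * tb

    # primary prediction u(z_t, r, t), keeps grad
    u_pred = model(z_t, r_frac, t_frac)

    # perturbed prediction at t + eps, no grad
    with torch.no_grad():
        t_eps = (t_frac + eps).clamp(0.0, 1.0)
        teb = t_eps[:, None, None, None]
        z_t_eps = latents_clean * (1.0 - teb) + noise * teb
        u_eps = model(z_t_eps, r_frac, t_eps)

    du_dt = (u_eps - u_pred).detach() / eps                       # finite-difference d/dt
    gap = (t_frac - r_frac)[:, None, None, None].clamp(min=1e-4)  # t - r

    u_shifted = u_pred + gap * du_dt              # u + (t - r) * du/dt
    v_target = (latents_clean - noise) / gap      # sign matches the model's convention above
    return F.mse_loss(u_shifted.float(), v_target.float())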
diff --git a/toolkit/models/mean_flow_adapter.py b/toolkit/models/mean_flow_adapter.py
index 1c5a6c15..2eb182c4 100644
--- a/toolkit/models/mean_flow_adapter.py
+++ b/toolkit/models/mean_flow_adapter.py
@@ -1,7 +1,7 @@
 import inspect
 import weakref
 import torch
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple
 from toolkit.lora_special import LoRASpecialNetwork
 from diffusers import FluxTransformer2DModel
 from diffusers.models.embeddings import (
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
     from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
     from toolkit.custom_adapter import CustomAdapter
+    from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import OmniGen2Transformer2DModel
 
 
 def mean_flow_time_text_embed_forward(
@@ -60,7 +61,7 @@ def mean_flow_time_text_guidance_embed_forward(
     # make zero timestep ending if none is passed
     if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
         timestep = torch.cat(
-            [timestep, torch.zeros_like(timestep)], dim=0
+            [timestep, torch.ones_like(timestep)], dim=0
         )  # timestep - 0 (final timestep) == same as start timestep
     timesteps_proj = self.time_proj(timestep)
     timesteps_emb = self.timestep_embedder(
@@ -111,6 +112,38 @@ def convert_flux_to_mean_flow(
         )
     )
 
+def mean_flow_omnigen2_time_text_embed_forward(
+    self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
+    if mean_flow_adapter.is_active and timestep.shape[0] == text_hidden_states.shape[0]:
+        timestep = torch.cat(
+            [timestep, torch.ones_like(timestep)], dim=0  # omnigen does reverse timesteps
+        )
+    timestep_proj = self.time_proj(timestep).to(dtype=dtype)
+    time_embed = self.timestep_embedder(timestep_proj)
+
+    # mean flow stuff
+    if mean_flow_adapter.is_active:
+        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
+        orig_dtype = time_embed.dtype
+        time_embed = time_embed.to(torch.float32)
+        time_embed_start, time_embed_end = time_embed.chunk(2, dim=0)
+        time_embed = mean_flow_adapter.mean_flow_timestep_embedder(
+            torch.cat([time_embed_start, time_embed_end], dim=-1)
+        )
+        time_embed = time_embed.to(orig_dtype)
+
+    caption_embed = self.caption_embedder(text_hidden_states)
+    return time_embed, caption_embed
+
+
+def convert_omnigen2_to_mean_flow(
+    transformer: 'OmniGen2Transformer2DModel',
+):
+    transformer.time_caption_embed.forward = partial(
+        mean_flow_omnigen2_time_text_embed_forward, transformer.time_caption_embed
+    )
+
 
 class MeanFlowAdapter(torch.nn.Module):
     def __init__(
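The patched forward above chunks the doubled time embedding back into its t and r halves and feeds their concatenation through `mean_flow_adapter.mean_flow_timestep_embedder`. That module is defined elsewhere in the adapter; the sketch below only assumes it maps (B, 2*dim) back to (B, dim), and the small MLP shown is a hypothetical stand-in:

import torch
import torch.nn as nn

class MeanFlowTimestepEmbedder(nn.Module):
    # hypothetical stand-in: merge the (t, r) embeddings into one conditioning vector
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, te_cat: torch.Tensor) -> torch.Tensor:
        return self.net(te_cat)

# shape walk-through for a batch of B with [t, r] stacked on the batch dim
B, dim = 4, 1024
time_embed = torch.randn(2 * B, dim)               # embeddings for t then r
te_start, te_end = time_embed.chunk(2, dim=0)      # (B, dim) each
merged = MeanFlowTimestepEmbedder(dim)(torch.cat([te_start, te_end], dim=-1))  # (B, dim)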
@@ -193,6 +226,13 @@ class MeanFlowAdapter(torch.nn.Module):
                 * transformer.config.attention_head_dim
             )
             convert_flux_to_mean_flow(transformer)
+
+        elif self.model_config.arch in ["omnigen2"]:
+            transformer: 'OmniGen2Transformer2DModel' = sd.unet
+            emb_dim = (
+                1024
+            )
+            convert_omnigen2_to_mean_flow(transformer)
         else:
             raise ValueError(f"Unsupported architecture: {self.model_config.arch}")
 
@@ -212,6 +252,8 @@ class MeanFlowAdapter(torch.nn.Module):
         # add our adapter as a weak ref
         if self.model_config.arch in ["flux", "flex2", "flex2"]:
             sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
+        elif self.model_config.arch in ["omnigen2"]:
+            sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self)
 
     def get_params(self):
         if self.lora is not None:
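At inference time, the stepping rule noted in the SDTrainer hunk (`stepped_latent = model_input + (timestep_next - timestep) * model_output`) suggests few-step sampling with the mean-flow prediction u(z, r, t). A rough sketch under that assumption, where `model(z, r, t)` is a hypothetical wrapper around the patched transformer taking timesteps in the scheduler's 0-1000 range:

import torch

@torch.no_grad()
def mean_flow_sample(model, noise, steps=1, total_steps=1000.0):
    # walk t from 1.0 (pure noise) down to 0.0 in `steps` equal jumps,
    # using r = t_next for each jump
    z = noise
    ts = torch.linspace(1.0, 0.0, steps + 1)
    for i in range(steps):
        t, t_next = float(ts[i]), float(ts[i + 1])
        b = z.shape[0]
        t_b = torch.full((b,), t * total_steps)
        r_b = torch.full((b,), t_next * total_steps)
        u = model(z, r_b, t_b)
        # stepped_latent = model_input + (timestep_next - timestep) * model_output
        z = z + (t_next - t) * u
    return z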