From fc83eb76912a93825d06f407a3ac6f9f16c65116 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 12 Jun 2025 08:00:51 -0600 Subject: [PATCH] WIP on mean flow loss. Still a WIP. --- build_and_push_docker_dev | 2 +- extensions_built_in/sd_trainer/SDTrainer.py | 372 ++++++++++++++++---- toolkit/config_modules.py | 2 +- toolkit/models/flux.py | 55 +++ toolkit/sampler.py | 3 + toolkit/samplers/mean_flow_scheduler.py | 93 +++++ 6 files changed, 465 insertions(+), 62 deletions(-) create mode 100644 toolkit/samplers/mean_flow_scheduler.py diff --git a/build_and_push_docker_dev b/build_and_push_docker_dev index 6a1a17d0..9098d8cd 100644 --- a/build_and_push_docker_dev +++ b/build_and_push_docker_dev @@ -4,7 +4,7 @@ VERSION=dev GIT_COMMIT=dev echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo." -echo "Building version: $VERSION and latest" +echo "Building version: $VERSION" # wait 2 seconds sleep 2 diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 5d0105f2..6e1daf39 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe from toolkit.util.wavelet_loss import wavelet_loss import torch.nn.functional as F +from toolkit.models.flux import convert_flux_to_mean_flow def flush(): @@ -134,6 +135,9 @@ class SDTrainer(BaseSDTrainProcess): def hook_before_train_loop(self): super().hook_before_train_loop() + if self.train_config.timestep_type == "mean_flow": + # todo handle non flux models + convert_flux_to_mean_flow(self.sd.transformer) if self.train_config.do_prior_divergence: self.do_prior_prediction = True @@ -595,19 +599,6 @@ class SDTrainer(BaseSDTrainProcess): return loss + additional_loss - - def get_diff_output_preservation_loss( - self, - noise_pred: torch.Tensor, - noise: torch.Tensor, - noisy_latents: torch.Tensor, - timesteps: torch.Tensor, - batch: 'DataLoaderBatchDTO', - mask_multiplier: Union[torch.Tensor, float] = 1.0, - prior_pred: Union[torch.Tensor, None] = None, - **kwargs - ): - loss_target = self.train_config.loss_target def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): return batch @@ -641,6 +632,254 @@ class SDTrainer(BaseSDTrainProcess): ) return loss + + + # ------------------------------------------------------------------ + # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative + # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper) + # This version avoids jvp / double-back-prop issues with Flash-Attention + # adapted from the work of lodestonerock + # ------------------------------------------------------------------ + def get_mean_flow_loss_wip( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + + + time_end = timesteps.float() / 1000 + # for timestep_r, we need values from timestep_end to 0.0 randomly + time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end + + # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype) + # Compute noised data points + # lerp_vector = noisy_latents + # compute instantaneous vector + instantaneous_vector = noise - batch_latents + + # finite difference method + epsilon_fd = 1e-3 + jitter_std = 1e-4 + epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std + epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4) + + # f(x + epsilon * v) for the primal (we backprop through here) + # mean_vec_val_pred = self.forward(lerp_vector, class_label) + mean_vec_val_pred = self.predict_noise( + noisy_latents=noisy_latents, + timesteps=torch.cat([time_end, time_origin], dim=0) * 1000, + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + with torch.no_grad(): + perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0) + # intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE! + perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector + # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label) + f_x_plus_eps_v = self.predict_noise( + noisy_latents=perturbed_lerp_vector, + timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000, + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon + mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered + mean_vec_grad = mean_vec_grad_fd + + + # calculate the regression target the mean vector + time_difference_broadcast = (time_end - time_origin)[:, None, None, None] + mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad + + # 5) MSE loss + loss = torch.nn.functional.mse_loss( + mean_vec_val_pred.float(), + mean_vec_target.float(), + reduction='none' + ) + with torch.no_grad(): + pure_loss = loss.mean().detach() + # add grad to pure_loss so it can be backwards without issues + pure_loss.requires_grad_(True) + # normalize the loss per batch element to 1.0 + # this method has large loss swings that can hurt the model. This method will prevent that + with torch.no_grad(): + loss_mean = loss.mean([1, 2, 3], keepdim=True) + loss = loss / loss_mean + loss = loss.mean() + + # backward the pure loss for logging + self.accelerator.backward(loss) + + # return the real loss for logging + return pure_loss + + + # ------------------------------------------------------------------ + # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative + # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper) + # This version avoids jvp / double-back-prop issues with Flash-Attention + # adapted from the work of lodestonerock + # ------------------------------------------------------------------ + def get_mean_flow_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + # ------------------------------------------------------------------ + # “Slow” Mean-Flow loss – finite-difference version + # (avoids JVP / double-backprop issues with Flash-Attention) + # ------------------------------------------------------------------ + dtype = get_torch_dtype(self.train_config.dtype) + total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps) # 1000 + base_eps = 1e-3 # this is one step when multiplied by 1000 + + with torch.no_grad(): + num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps + batch_size = batch.latents.shape[0] + timestep_t_list = [] + timestep_r_list = [] + for i in range(batch_size): + t1 = random.randint(0, num_train_timesteps - 1) + t2 = random.randint(0, num_train_timesteps - 1) + t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)] + t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)] + if (t_t - t_r).item() < base_eps * 1000: + # we need to ensure the time gap is wider than the epsilon(one step) + scaled_eps = base_eps * 1000 + if t_t.item() + scaled_eps > 1000: + t_r = t_r - scaled_eps + else: + t_t = t_t + scaled_eps + timestep_t_list.append(t_t) + timestep_r_list.append(t_r) + eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps + timesteps_t = torch.stack(timestep_t_list, dim=0).float() + timesteps_r = torch.stack(timestep_r_list, dim=0).float() + + # fractions in [0,1] + t_frac = timesteps_t / total_steps + r_frac = timesteps_r / total_steps + + # 2) construct data points + latents_clean = batch.latents.to(dtype) + noise_sample = noise.to(dtype) + + lerp_vector = noise_sample * t_frac[:, None, None, None] \ + + latents_clean * (1.0 - t_frac[:, None, None, None]) + + if hasattr(self.sd, 'get_loss_target'): + instantaneous_vector = self.sd.get_loss_target( + noise=noise_sample, + batch=batch, + timesteps=timesteps, + ).detach() + else: + instantaneous_vector = noise_sample - latents_clean # v_t (B,C,H,W) + + # 3) finite-difference JVP approximation (bump z **and** t) + # eps_base, eps_jitter = 1e-3, 1e-4 + # eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4) + jitter = 1e-4 + eps = value_map( + torch.rand_like(t_frac), + 0.0, + 1.0, + base_eps, + base_eps + jitter + ) + + # eps = 1e-3 + # primary prediction (needs grad) + mean_vec_pred = self.predict_noise( + noisy_latents=lerp_vector, + timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps, + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + # secondary prediction: bump both latent and timestep by ε + with torch.no_grad(): + # lerp_perturbed = lerp_vector + eps * instantaneous_vector + t_frac_plus_eps = t_frac + eps # bump time fraction + lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \ + + latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None]) + + f_x_plus_eps_v = self.predict_noise( + noisy_latents=lerp_perturbed, + timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps, + conditional_embeds=conditional_embeds, + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + + # finite-difference JVP: (f(x+εv) – f(x)) / ε + mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps + # time_gap = (t_frac - r_frac)[:, None, None, None] + # mean_vec_scaler = time_gap / eps + # mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler + # mean_vec_grad = mean_vec_grad.detach() # stop-grad as in Eq. 11 + + # 4) regression target for the mean vector + time_gap = (t_frac - r_frac)[:, None, None, None] + mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad + # mean_vec_target = instantaneous_vector - mean_vec_grad + + # # 5) MSE loss + # loss = torch.nn.functional.mse_loss( + # mean_vec_pred.float(), + # mean_vec_target.float() + # ) + # return loss + # 5) MSE loss + loss = torch.nn.functional.mse_loss( + mean_vec_pred.float(), + mean_vec_target.float(), + reduction='none' + ) + with torch.no_grad(): + pure_loss = loss.mean().detach() + # add grad to pure_loss so it can be backwards without issues + pure_loss.requires_grad_(True) + # normalize the loss per batch element to 1.0 + # this method has large loss swings that can hurt the model. This method will prevent that + # with torch.no_grad(): + # loss_mean = loss.mean([1, 2, 3], keepdim=True) + # loss = loss / loss_mean + loss = loss.mean() + + # backward the pure loss for logging + self.accelerator.backward(loss) + + # return the real loss for logging + return pure_loss + def get_prior_prediction( @@ -1495,6 +1734,52 @@ class SDTrainer(BaseSDTrainProcess): pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample self.before_unet_predict() + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + with self.timer('condition_noisy_latents'): + # do it for the model + noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) + if self.adapter and isinstance(self.adapter, CustomAdapter): + noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) + + if self.train_config.timestep_type == 'next_sample': + with self.timer('next_sample_step'): + with torch.no_grad(): + + stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps] + stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies] + stepped_timesteps = torch.stack(stepped_timesteps, dim=0) + + # do a sample at the current timestep and step it, then determine new noise + next_sample_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + stepped_latents = self.sd.step_scheduler( + next_sample_pred, + noisy_latents, + timesteps, + self.sd.noise_scheduler + ) + # stepped latents is our new noisy latents. Now we need to determine noise in the current sample + noisy_latents = stepped_latents + original_samples = batch.latents.to(self.device_torch, dtype=dtype) + # todo calc next timestep, for now this may work as it + t_01 = (stepped_timesteps / 1000).to(original_samples.device) + if len(stepped_latents.shape) == 4: + t_01 = t_01.view(-1, 1, 1, 1) + elif len(stepped_latents.shape) == 5: + t_01 = t_01.view(-1, 1, 1, 1, 1) + else: + raise ValueError("Unknown stepped latents shape", stepped_latents.shape) + next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01 + noise = next_sample_noise + timesteps = stepped_timesteps # do a prior pred if we have an unconditional image, we will swap out the giadance later if batch.unconditional_latents is not None or self.do_guided_loss: # do guided loss @@ -1511,54 +1796,21 @@ class SDTrainer(BaseSDTrainProcess): mask_multiplier=mask_multiplier, prior_pred=prior_pred, ) - + + elif self.train_config.loss_type == 'mean_flow': + loss = self.get_mean_flow_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + prior_pred=prior_pred, + ) else: - if unconditional_embeds is not None: - unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() - with self.timer('condition_noisy_latents'): - # do it for the model - noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) - if self.adapter and isinstance(self.adapter, CustomAdapter): - noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) - - if self.train_config.timestep_type == 'next_sample': - with self.timer('next_sample_step'): - with torch.no_grad(): - - stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps] - stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies] - stepped_timesteps = torch.stack(stepped_timesteps, dim=0) - - # do a sample at the current timestep and step it, then determine new noise - next_sample_pred = self.predict_noise( - noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), - timesteps=timesteps, - conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), - unconditional_embeds=unconditional_embeds, - batch=batch, - **pred_kwargs - ) - stepped_latents = self.sd.step_scheduler( - next_sample_pred, - noisy_latents, - timesteps, - self.sd.noise_scheduler - ) - # stepped latents is our new noisy latents. Now we need to determine noise in the current sample - noisy_latents = stepped_latents - original_samples = batch.latents.to(self.device_torch, dtype=dtype) - # todo calc next timestep, for now this may work as it - t_01 = (stepped_timesteps / 1000).to(original_samples.device) - if len(stepped_latents.shape) == 4: - t_01 = t_01.view(-1, 1, 1, 1) - elif len(stepped_latents.shape) == 5: - t_01 = t_01.view(-1, 1, 1, 1, 1) - else: - raise ValueError("Unknown stepped latents shape", stepped_latents.shape) - next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01 - noise = next_sample_noise - timesteps = stepped_timesteps - with self.timer('predict_unet'): noise_pred = self.predict_noise( noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 0e8b8af0..c1fd396c 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -413,7 +413,7 @@ class TrainConfig: self.correct_pred_norm = kwargs.get('correct_pred_norm', False) self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) - self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm + self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow # scale the prediction by this. Increase for more detail, decrease for less self.pred_scaler = kwargs.get('pred_scaler', 1.0) diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py index c2fb5ac9..42194179 100644 --- a/toolkit/models/flux.py +++ b/toolkit/models/flux.py @@ -4,6 +4,7 @@ from functools import partial from typing import Optional import torch from diffusers import FluxTransformer2DModel +from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection): @@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux( transformer._pre_gpu_split_to = transformer.to transformer.to = partial(new_device_to, transformer) + + +def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection): + # make zero timestep ending if none is passed + if timestep.shape[0] == pooled_projection.shape[0] // 2: + timestep = torch.cat([timestep, timestep], dim=0) # timestep - 0 (final timestep) == same as start timestep + + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + +def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection): + # make zero timestep ending if none is passed + if timestep.shape[0] == pooled_projection.shape[0] // 2: + timestep = torch.cat([timestep, timestep], dim=0) + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0) + + time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +def convert_flux_to_mean_flow( + transformer: FluxTransformer2DModel, +): + if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_embed_forward, transformer.time_text_embed + ) + elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings): + transformer.time_text_embed.forward = partial( + mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed + ) + else: + raise ValueError( + "Unsupported time_text_embed type: {}".format( + type(transformer.time_text_embed) + ) + ) + \ No newline at end of file diff --git a/toolkit/sampler.py b/toolkit/sampler.py index 47fdc205..c36af3de 100644 --- a/toolkit/sampler.py +++ b/toolkit/sampler.py @@ -16,6 +16,7 @@ from diffusers import ( LCMScheduler, FlowMatchEulerDiscreteScheduler, ) +from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler @@ -159,6 +160,8 @@ def get_sampler( scheduler_cls = LCMScheduler elif sampler == "custom_lcm": scheduler_cls = CustomLCMScheduler + elif sampler == "mean_flow": + scheduler_cls = MeanFlowScheduler elif sampler == "flowmatch": scheduler_cls = CustomFlowMatchEulerDiscreteScheduler config_to_use = copy.deepcopy(flux_config) diff --git a/toolkit/samplers/mean_flow_scheduler.py b/toolkit/samplers/mean_flow_scheduler.py new file mode 100644 index 00000000..6ce22ba1 --- /dev/null +++ b/toolkit/samplers/mean_flow_scheduler.py @@ -0,0 +1,93 @@ +from typing import Union +from diffusers import FlowMatchEulerDiscreteScheduler +import torch +from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme + +from dataclasses import dataclass +from typing import Optional, Tuple +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_noise_sigma = 1.0 + self.timestep_type = "linear" + + with torch.no_grad(): + # create weights for timesteps + num_timesteps = 1000 + + # Create linear timesteps from 1000 to 1 + timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu") + + self.linear_timesteps = timesteps + pass + + def get_weights_for_timesteps( + self, timesteps: torch.Tensor, v2=False, timestep_type="linear" + ) -> torch.Tensor: + # Get the indices of the timesteps + step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + + weights = 1.0 + + # Get the weights for the timesteps + if timestep_type == "weighted": + weights = torch.tensor( + [default_weighing_scheme[i] for i in step_indices], + device=timesteps.device, + dtype=timesteps.dtype, + ) + + return weights + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + t_01 = (timesteps / 1000).to(original_samples.device) + noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise + return noisy_model_input + + def scale_model_input( + self, sample: torch.Tensor, timestep: Union[float, torch.Tensor] + ) -> torch.Tensor: + return sample + + def set_train_timesteps(self, num_timesteps, device, **kwargs): + timesteps = torch.linspace(1000, 1, num_timesteps, device=device) + self.timesteps = timesteps + return timesteps + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + return_dict: bool = True, + **kwargs: Optional[dict], + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + + # single euler step (Eq. 5 ⇒ x₀ = x₁ − uθ) + output = sample - model_output + if not return_dict: + return (output,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)