from typing import Union

import torch
from diffusers import FlowMatchEulerDiscreteScheduler


class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """FlowMatchEulerDiscreteScheduler with a flow-matching ``add_noise``.

    ``add_noise`` linearly interpolates between clean samples and noise using
    the sigma associated with each timestep, and ``scale_model_input`` returns
    its input unchanged.
    """

    def get_sigmas(
        self,
        timesteps: Union[float, torch.Tensor],
        n_dim: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
    ) -> torch.Tensor:
        """Look up the sigma for each timestep and shape it for broadcasting.

        Each entry of ``timesteps`` must exactly equal an entry of
        ``self.timesteps``; its sigma is read from ``self.sigmas`` at the same
        position (the first matching position if the schedule has duplicates).
        NOTE(review): this relies on ``self.sigmas`` and ``self.timesteps``
        being index-aligned — confirm against the base scheduler.

        Args:
            timesteps: Timesteps to look up. A Python scalar, 0-d tensor, or
                tensor of any shape (it is flattened).
            n_dim: Minimum number of dimensions of the result; trailing
                singleton dimensions are appended until it is reached.
            dtype: dtype of the returned sigmas.
            device: Device of the returned sigmas.

        Returns:
            Tensor of shape ``(N, 1, ..., 1)`` (``N`` = number of timesteps)
            with ``max(1, n_dim)`` dimensions.

        Raises:
            ValueError: If a timestep does not appear in ``self.timesteps``.
        """
        sigmas = self.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.timesteps.to(device)
        # as_tensor + reshape(-1) also handles Python scalars and 0-d tensors,
        # which cannot be iterated over element-wise.
        timesteps = torch.as_tensor(timesteps, device=device).reshape(-1)

        # (N, T) table of exact matches: one vectorized comparison instead of a
        # per-timestep .nonzero().item(), which forces a host sync each time.
        matches = timesteps.unsqueeze(1) == schedule_timesteps.unsqueeze(0)
        found = matches.any(dim=1)
        if not bool(found.all()):
            missing = timesteps[~found].tolist()
            raise ValueError(
                f"timesteps {missing} are not present in the scheduler's timestep schedule"
            )
        # argmax returns the index of the first maximal value, i.e. the first match.
        step_indices = matches.int().argmax(dim=1)

        sigma = sigmas[step_indices].flatten()
        # Append singleton dims so sigma broadcasts against an n_dim-dim tensor.
        while sigma.ndim < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise ``original_samples`` by flow-matching linear interpolation.

        Computes ``(1 - sigma) * original_samples + sigma * noise``. ``sigma``
        is looked up per timestep via ``get_sigmas`` in the dtype and on the
        device of ``original_samples``. It is shaped ``(N, 1, ..., 1)`` so it
        broadcasts over every dimension after the first.

        Reference:
        https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
        """
        sigmas = self.get_sigmas(
            timesteps,
            n_dim=original_samples.ndim,
            dtype=original_samples.dtype,
            device=original_samples.device,
        )
        return (1.0 - sigmas) * original_samples + sigmas * noise

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """Return ``sample`` unchanged; no input scaling is applied.

        ``timestep`` is unused.
        """
        return sample