WIP on mean flow loss. Still a WIP.

This commit is contained in:
Jaret Burkett
2025-06-12 08:00:51 -06:00
parent cf11f128b9
commit fc83eb7691
6 changed files with 465 additions and 62 deletions

View File

@@ -413,7 +413,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)

View File

@@ -4,6 +4,7 @@ from functools import partial
from typing import Optional
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings
def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
@@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux(
transformer._pre_gpu_split_to = transformer.to
transformer.to = partial(new_device_to, transformer)
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0)
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)

View File

@@ -16,6 +16,7 @@ from diffusers import (
LCMScheduler,
FlowMatchEulerDiscreteScheduler,
)
from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
@@ -159,6 +160,8 @@ def get_sampler(
scheduler_cls = LCMScheduler
elif sampler == "custom_lcm":
scheduler_cls = CustomLCMScheduler
elif sampler == "mean_flow":
scheduler_cls = MeanFlowScheduler
elif sampler == "flowmatch":
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
config_to_use = copy.deepcopy(flux_config)

View File

@@ -0,0 +1,93 @@
from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
from dataclasses import dataclass
from typing import Optional, Tuple
from diffusers.utils import BaseOutput
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_noise_sigma = 1.0
self.timestep_type = "linear"
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# Create linear timesteps from 1000 to 1
timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
pass
def get_weights_for_timesteps(
self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
weights = 1.0
# Get the weights for the timesteps
if timestep_type == "weighted":
weights = torch.tensor(
[default_weighing_scheme[i] for i in step_indices],
device=timesteps.device,
dtype=timesteps.dtype,
)
return weights
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
return noisy_model_input
def scale_model_input(
self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, **kwargs):
timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
**kwargs: Optional[dict],
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
# single euler step (Eq. 5 ⇒ x₀ = x₁ uθ)
output = sample - model_output
if not return_dict:
return (output,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)