Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 11:11:37 +00:00)
WIP on mean flow loss. Still a WIP.
@@ -413,7 +413,7 @@ class TrainConfig:
         self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
         self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
 
-        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm
+        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm, mean_flow
 
         # scale the prediction by this. Increase for more detail, decrease for less
         self.pred_scaler = kwargs.get('pred_scaler', 1.0)
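The new loss type is opted into through the same kwargs path as the existing ones. A minimal sketch of a config fragment; the nesting around these keys and the import path of TrainConfig are assumptions, only the key names come from the hunk above:

# hypothetical config fragment handed to TrainConfig(**kwargs)
train_kwargs = {
    "loss_type": "mean_flow",  # previously one of: mse, mae, wavelet, pixelspace, cfm
    "pred_scaler": 1.0,
}
config = TrainConfig(**train_kwargs)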
@@ -4,6 +4,7 @@ from functools import partial
 from typing import Optional
 import torch
 from diffusers import FluxTransformer2DModel
+from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings
 
 
 def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
@@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux(
 
     transformer._pre_gpu_split_to = transformer.to
     transformer.to = partial(new_device_to, transformer)
+
+
+def mean_flow_time_text_embed_forward(self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
+    # make zero timestep ending if none is passed
+    if timestep.shape[0] == pooled_projection.shape[0] // 2:
+        timestep = torch.cat([timestep, timestep], dim=0)  # timestep - 0 (final timestep) == same as start timestep
+
+    timesteps_proj = self.time_proj(timestep)
+    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+    pooled_projections = self.text_embedder(pooled_projection)
+
+    conditioning = timesteps_emb + pooled_projections
+
+    return conditioning
+
+
+def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
+    # make zero timestep ending if none is passed
+    if timestep.shape[0] == pooled_projection.shape[0] // 2:
+        timestep = torch.cat([timestep, timestep], dim=0)
+    timesteps_proj = self.time_proj(timestep)
+    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+    guidance_proj = self.time_proj(guidance)
+    guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+    timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
+
+    time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
+
+    pooled_projections = self.text_embedder(pooled_projection)
+    conditioning = time_guidance_emb + pooled_projections
+
+    return conditioning
+
+
+def convert_flux_to_mean_flow(
+    transformer: FluxTransformer2DModel,
+):
+    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
+        transformer.time_text_embed.forward = partial(
+            mean_flow_time_text_embed_forward, transformer.time_text_embed
+        )
+    elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
+        transformer.time_text_embed.forward = partial(
+            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
+        )
+    else:
+        raise ValueError(
+            "Unsupported time_text_embed type: {}".format(
+                type(transformer.time_text_embed)
+            )
+        )
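convert_flux_to_mean_flow only swaps the forward method of the transformer's time/text embedding module, so the same weights can consume a start timestep and an end timestep stacked on the batch dimension. A rough usage sketch; the checkpoint id, dtype, and the exact way the caller builds the stacked timestep tensor are assumptions, not shown in this diff:

import torch
from diffusers import FluxTransformer2DModel

# hypothetical setup: load the Flux transformer, then patch its embedder in place
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
convert_flux_to_mean_flow(transformer)

# the patched guidance variant chunks the timestep embedding in two, so the caller
# presumably passes something like torch.cat([t_start, t_end], dim=0) as `timestep`,
# while pooled text embeddings and latents keep the original batch size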
@@ -16,6 +16,7 @@ from diffusers import (
     LCMScheduler,
     FlowMatchEulerDiscreteScheduler,
 )
+from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler
 
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
 
@@ -159,6 +160,8 @@ def get_sampler(
         scheduler_cls = LCMScheduler
     elif sampler == "custom_lcm":
         scheduler_cls = CustomLCMScheduler
+    elif sampler == "mean_flow":
+        scheduler_cls = MeanFlowScheduler
     elif sampler == "flowmatch":
         scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
         config_to_use = copy.deepcopy(flux_config)
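Requesting sampler == "mean_flow" now resolves to the new scheduler class. The rest of get_sampler's signature is outside this diff, so here is only a sketch of constructing the class it dispatches to; constructor defaults are inherited from FlowMatchEulerDiscreteScheduler:

from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler

# defaults come from FlowMatchEulerDiscreteScheduler; set_train_timesteps is defined
# in the new file below and returns a linear 1000 -> 1 grid
scheduler = MeanFlowScheduler()
timesteps = scheduler.set_train_timesteps(1000, device="cpu")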
toolkit/samplers/mean_flow_scheduler.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme

from dataclasses import dataclass
from typing import Optional, Tuple
from diffusers.utils import BaseOutput


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_noise_sigma = 1.0
        self.timestep_type = "linear"

        with torch.no_grad():
            # create weights for timesteps
            num_timesteps = 1000

            # Create linear timesteps from 1000 to 1
            timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu")

            self.linear_timesteps = timesteps

    def get_weights_for_timesteps(
        self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
    ) -> torch.Tensor:
        # Get the indices of the timesteps
        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]

        weights = 1.0

        # Get the weights for the timesteps
        if timestep_type == "weighted":
            weights = torch.tensor(
                [default_weighing_scheme[i] for i in step_indices],
                device=timesteps.device,
                dtype=timesteps.dtype,
            )

        return weights

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # linear flow-matching interpolation: timestep 0 is clean data, 1000 is pure noise
        t_01 = (timesteps / 1000).to(original_samples.device)
        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
        return noisy_model_input

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        return sample

    def set_train_timesteps(self, num_timesteps, device, **kwargs):
        timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
        self.timesteps = timesteps
        return timesteps

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs: Optional[dict],
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:

        # single euler step (Eq. 5 ⇒ x₀ = x₁ − uθ)
        output = sample - model_output
        if not return_dict:
            return (output,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)
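The scheduler and embedder are wired up here, but the mean flow loss itself is not part of this WIP commit. For orientation, a generic sketch of a MeanFlow-style objective (an average-velocity model supervised through a jacobian-vector product); the model_fn wrapper, the 0-1 time convention, and the plain MSE weighting are assumptions, not code from this repository:

import torch
import torch.nn.functional as F

def mean_flow_loss_sketch(model_fn, x0, noise, t, r):
    # model_fn(z, r, t) -> predicted average velocity u(z_t, r, t); hypothetical
    # wrapper around the patched transformer. t and r are in [0, 1], shape (N,).
    t_ = t.view(-1, 1, 1, 1)
    r_ = r.view(-1, 1, 1, 1)

    z_t = (1.0 - t_) * x0 + t_ * noise  # same interpolation as MeanFlowScheduler.add_noise
    v = noise - x0                      # instantaneous flow-matching velocity

    # total time derivative of u along the trajectory, via a JVP with
    # tangents (dz/dt, dr/dt, dt/dt) = (v, 0, 1)
    u, du_dt = torch.func.jvp(
        model_fn, (z_t, r, t), (v, torch.zeros_like(r), torch.ones_like(t))
    )

    # MeanFlow identity: u(z_t, r, t) = v - (t - r) * d/dt u(z_t, r, t)
    target = v - (t_ - r_) * du_dt
    return F.mse_loss(u, target.detach())

How the trainer samples (t, r) pairs, weights timesteps, and handles guidance is not reconstructed here.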