# mirror of https://github.com/ostris/ai-toolkit.git
# synced 2026-01-26 16:39:47 +00:00
# 94 lines, 3.0 KiB, Python
from typing import Union
|
||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||
import torch
|
||
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
|
||
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Tuple
|
||
from diffusers.utils import BaseOutput
|
||
|
||
|
||
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Result container returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Sample computed for the previous timestep, `(x_{t-1})`. Feed it back in as the next model
            input of the denoising loop.
    """

    prev_sample: torch.FloatTensor
|
||
|
||
|
||
class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
    """Flow-match scheduler variant whose `step` resolves the sample in one jump.

    Timesteps live on a linear grid running from 1000 down to 1. Noising is a
    linear interpolation between the clean sample and the noise, and `step`
    subtracts the model output from the current sample exactly once.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent scheduler and set local defaults."""
        super().__init__(*args, **kwargs)
        self.init_noise_sigma = 1.0
        self.timestep_type = "linear"

        # Reference grid: 1000 linearly spaced timesteps from 1000 down to 1,
        # kept on the CPU.
        with torch.no_grad():
            num_timesteps = 1000
            self.linear_timesteps = torch.linspace(
                1000, 1, num_timesteps, device="cpu"
            )

    def get_weights_for_timesteps(
        self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
    ) -> Union[float, torch.Tensor]:
        """Return the loss weights for the given timesteps.

        Args:
            timesteps: Timesteps to weight. In "weighted" mode every value must
                occur exactly once in `self.timesteps`.
            v2: Unused by this scheduler.
            timestep_type: `"weighted"` looks the weights up in
                `default_weighing_scheme`; any other value means uniform
                weighting.

        Returns:
            The float `1.0` for uniform weighting, otherwise a tensor of
            per-timestep weights with the device and dtype of `timesteps`.

        Raises:
            An error from `.item()` in "weighted" mode when a timestep does not
            match exactly one entry of `self.timesteps`.
        """
        if timestep_type != "weighted":
            # Uniform weighting needs no index lookup. Skipping it avoids one
            # `.item()` host sync per timestep and no longer fails for
            # timesteps that are not on the current schedule.
            return 1.0

        # Position of every requested timestep inside the active schedule.
        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]

        return torch.tensor(
            [default_weighing_scheme[i] for i in step_indices],
            device=timesteps.device,
            dtype=timesteps.dtype,
        )

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Blend clean samples with noise according to the timestep.

        Computes `(1 - t) * original_samples + t * noise` with
        `t = timesteps / 1000`.

        Args:
            original_samples: Clean samples.
            noise: Noise to blend in; must broadcast against `original_samples`.
            timesteps: Timesteps on the 0..1000 scale. May be a scalar or
                1-element tensor shared by all samples, or a 1-D tensor with
                one entry per sample along dim 0.

        Returns:
            The noised samples.
        """
        t_01 = (timesteps / 1000).to(original_samples.device)

        # A 1-D batch of timesteps has to scale each sample along dim 0. Left
        # as `(batch,)` it would broadcast against the trailing axis instead,
        # so give it singleton dims: `(batch, 1, ..., 1)`.
        if (
            t_01.ndim == 1
            and original_samples.ndim > 1
            and t_01.shape[0] == original_samples.shape[0]
        ):
            t_01 = t_01.reshape(-1, *([1] * (original_samples.ndim - 1)))

        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
        return noisy_model_input

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """Return `sample` unchanged; no input scaling is applied."""
        return sample

    def set_train_timesteps(self, num_timesteps, device, **kwargs):
        """Install a linear training schedule and return it.

        Args:
            num_timesteps: Number of points on the schedule.
            device: Device on which the schedule tensor is created.
            **kwargs: Accepted and ignored.

        Returns:
            The tensor of `num_timesteps` values linearly spaced from 1000 down
            to 1, which is also stored as `self.timesteps`.
        """
        timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
        self.timesteps = timesteps
        return timesteps

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs: Optional[dict],
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """Take the single denoising step.

        Args:
            model_output: Prediction of the model for `sample`.
            timestep: Unused; the update does not depend on it.
            sample: Current noisy sample.
            return_dict: When false, return a 1-tuple instead of the output
                dataclass.
            **kwargs: Accepted and ignored.

        Returns:
            `FlowMatchEulerDiscreteSchedulerOutput` holding
            `sample - model_output` as `prev_sample`, or `(prev_sample,)` when
            `return_dict` is false.
        """
        # single euler step (Eq. 5 ⇒ x₀ = x₁ − uθ)
        output = sample - model_output

        if not return_dict:
            return (output,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)
|