Added timestep modifications to lcm scheduler for more evenly spaced timesteps

Jaret Burkett
2023-11-17 23:26:52 -07:00
parent 6280284d8b
commit fbec68681d
4 changed files with 277 additions and 175 deletions
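The main behavioural change is in how the custom scheduler's `set_timesteps` picks inference timesteps from the LCM "origin" schedule: the old slice with a fixed stride could leave one end of the schedule unused, while the new version selects indices with `np.linspace` so the chosen timesteps are spread evenly across the whole range. A minimal standalone sketch of the two selection strategies (plain NumPy, variable names are illustrative):

```python
import numpy as np

num_train_timesteps = 1000
original_steps = 50
num_inference_steps = 8

# LCM origin schedule: every k-th training timestep
k = num_train_timesteps // original_steps
lcm_origin_timesteps = np.asarray(list(range(1, original_steps + 1))) * k - 1  # 19, 39, ..., 999

# Old selection: fixed stride from the end, then truncate
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
old_timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]   # 999, 879, ..., 159

# New selection: (approximately) evenly spaced indices over the reversed schedule
reversed_schedule = lcm_origin_timesteps[::-1].copy()
indices = np.floor(np.linspace(0, len(reversed_schedule) - 1, num=num_inference_steps)).astype(np.int64)
new_timesteps = reversed_schedule[indices]                                      # 999, 859, ..., 19
```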

View File

@@ -688,7 +688,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with self.timer('prepare_noise'):
             num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
-            if self.train_config.noise_scheduler == 'lcm':
+            if self.train_config.noise_scheduler in ['custom_lcm']:
+                # we store this value on our custom one
+                self.sd.noise_scheduler.set_timesteps(
+                    self.sd.noise_scheduler.train_timesteps, device=self.device_torch
+                )
+            elif self.train_config.noise_scheduler in ['lcm']:
                 self.sd.noise_scheduler.set_timesteps(
                     num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
                 )
@@ -727,12 +732,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
             )
         elif self.train_config.content_or_style == 'balanced':
-            timesteps = torch.randint(
-                min_noise_steps,
-                max_noise_steps,
-                (batch_size,),
-                device=self.device_torch
-            )
+            if min_noise_steps == max_noise_steps:
+                timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
+            else:
+                timesteps = torch.randint(
+                    min_noise_steps,
+                    max_noise_steps,
+                    (batch_size,),
+                    device=self.device_torch
+                )
             timesteps = timesteps.long()
         else:
             raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
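The new guard in the `balanced` branch avoids a degenerate call: `torch.randint(low, high, ...)` samples from the half-open interval `[low, high)`, so it raises when `min_noise_steps == max_noise_steps`. A small sketch of the behaviour with illustrative values:

```python
import torch

batch_size = 4
min_noise_steps = max_noise_steps = 500  # degenerate range

if min_noise_steps == max_noise_steps:
    # randint over the empty interval [500, 500) would raise, so fill the batch directly
    timesteps = torch.ones((batch_size,), device="cpu") * min_noise_steps
else:
    timesteps = torch.randint(min_noise_steps, max_noise_steps, (batch_size,), device="cpu")

timesteps = timesteps.long()  # tensor([500, 500, 500, 500])
```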

View File

@@ -17,7 +17,7 @@ from diffusers import (
 from k_diffusion.external import CompVisDenoiser
-from toolkit.samplers.scheduling_ddpm import ADDPMScheduler
+from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler

 # scheduler:
 SCHEDULER_LINEAR_START = 0.00085
@@ -78,8 +78,8 @@ def get_sampler(
         scheduler_cls = KDPM2AncestralDiscreteScheduler
     elif sampler == "lcm":
         scheduler_cls = LCMScheduler
-    elif sampler == "addpm":
-        scheduler_cls = ADDPMScheduler
+    elif sampler == "custom_lcm":
+        scheduler_cls = CustomLCMScheduler

     config = copy.deepcopy(sdxl_sampler_config)
     config.update(sched_init_args)

View File

@@ -1,4 +1,4 @@
-# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion

 import math
 from dataclasses import dataclass
@@ -22,13 +23,16 @@ import numpy as np
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import BaseOutput
+from diffusers.utils import BaseOutput, logging
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+from diffusers.schedulers.scheduling_utils import SchedulerMixin

+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 @dataclass
-class DDPMSchedulerOutput(BaseOutput):
+class LCMSchedulerOutput(BaseOutput):
     """
     Output class for the scheduler's `step` function output.
@@ -42,9 +46,10 @@ class DDPMSchedulerOutput(BaseOutput):
     """

     prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    denoised: Optional[torch.FloatTensor] = None


+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
     max_beta=0.999,
@@ -89,12 +94,52 @@ def betas_for_alpha_bar(
     return torch.tensor(betas, dtype=torch.float32)


-class ADDPMScheduler(SchedulerMixin, ConfigMixin):
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
     """
-    `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
-
-    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
-    methods the library implements for all schedulers such as loading and saving.
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
+class CustomLCMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+    non-Markovian guidance.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
+    attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
+    accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
+    functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.

     Args:
         num_train_timesteps (`int`, defaults to 1000):
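The `rescale_zero_terminal_snr` helper (copied from the diffusers DDIM scheduler) shifts and rescales the square roots of the cumulative alphas so the final timestep carries zero signal. A quick sanity check of that property, assuming the module is importable under the path used by the sampler registry above:

```python
import torch
from toolkit.samplers.custom_lcm_scheduler import rescale_zero_terminal_snr

# scaled_linear betas, matching the scheduler's defaults (beta_start=0.00085, beta_end=0.012)
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
rescaled = rescale_zero_terminal_snr(betas)

alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
print(alphas_cumprod[0])   # ~0.99915, first step unchanged
print(alphas_cumprod[-1])  # ~0.0, terminal SNR is zero after rescaling
```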
@@ -106,13 +151,23 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule (`str`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-        variance_type (`str`, defaults to `"fixed_small"`):
-            Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
-            `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+        trained_betas (`np.ndarray`, *optional*):
+            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+        original_inference_steps (`int`, *optional*, defaults to 50):
+            The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
+            will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
         clip_sample (`bool`, defaults to `True`):
             Clip the predicted sample for numerical stability.
         clip_sample_range (`float`, defaults to 1.0):
             The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        set_alpha_to_one (`bool`, defaults to `True`):
+            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the alpha value at step 0.
+        steps_offset (`int`, defaults to 0):
+            An offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+            Diffusion.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -127,32 +182,38 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing (`str`, defaults to `"leading"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
-            An offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
-            Diffusion.
+        timestep_scaling (`float`, defaults to 10.0):
+            The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+            `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+            error at the default of `10.0` is already pretty small).
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

-    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
     def __init__(
         self,
         num_train_timesteps: int = 1000,
-        beta_start: float = 0.0001,
-        beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_start: float = 0.00085,
+        beta_end: float = 0.012,
+        beta_schedule: str = "scaled_linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
-        variance_type: str = "fixed_small",
-        clip_sample: bool = True,
+        original_inference_steps: int = 50,
+        clip_sample: bool = False,
+        clip_sample_range: float = 1.0,
+        set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
         prediction_type: str = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
-        clip_sample_range: float = 1.0,
         sample_max_value: float = 1.0,
         timestep_spacing: str = "leading",
-        steps_offset: int = 0,
+        timestep_scaling: float = 10.0,
+        rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -166,27 +227,55 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
+        elif beta_schedule == "sigmoid":
+            # GeoDiff sigmoid schedule
+            betas = torch.linspace(-6, 6, num_train_timesteps)
+            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-        self.one = torch.tensor(1.0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
-        self.is_training = False
+        self.original_inference_steps = 50

         # setable values
-        self.custom_timesteps = False
         self.num_inference_steps = None
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
-        self.variance_type = variance_type
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+        self.train_timesteps = 1000
+        self._step_index = None
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+
+        index_candidates = (self.timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        if len(index_candidates) > 1:
+            step_index = index_candidates[1]
+        else:
+            step_index = index_candidates[0]
+
+        self._step_index = step_index.item()
+
+    @property
+    def step_index(self):
+        return self._step_index
+
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
@@ -198,73 +287,13 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.

         Returns:
             `torch.FloatTensor`:
                 A scaled input sample.
         """
         return sample

-    def set_timesteps(
-        self,
-        num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
-        timesteps: Optional[List[int]] = None,
-    ):
-        original_steps = 50 if num_inference_steps != 1000 else 1000
-        train_timesteps = self.config['num_train_timesteps']
-        strength = 1.0
-        c = train_timesteps // original_steps
-        # LCM Training Steps Schedule
-        lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
-        skipping_step = len(lcm_origin_timesteps) // num_inference_steps
-        # LCM Inference Steps Schedule
-        timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
-        self._step_index = None
-        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
-
-    def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        prev_t = self.previous_timestep(t)
-
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
-        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
-
-        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
-        # and sample from it to get previous sample
-        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
-
-        # we always take the log of variance, so clamp it to ensure it's not 0
-        variance = torch.clamp(variance, min=1e-20)
-
-        if variance_type is None:
-            variance_type = self.config.variance_type
-
-        # hacks - were probably added for training stability
-        if variance_type == "fixed_small":
-            variance = variance
-        # for rl-diffuser https://arxiv.org/abs/2205.09991
-        elif variance_type == "fixed_small_log":
-            variance = torch.log(variance)
-            variance = torch.exp(0.5 * variance)
-        elif variance_type == "fixed_large":
-            variance = current_beta_t
-        elif variance_type == "fixed_large_log":
-            # Glide max_log
-            variance = torch.log(current_beta_t)
-        elif variance_type == "learned":
-            return predicted_variance
-        elif variance_type == "learned_range":
-            min_log = torch.log(variance)
-            max_log = torch.log(current_beta_t)
-            frac = (predicted_variance + 1) / 2
-            variance = frac * max_log + (1 - frac) * min_log
-
-        return variance
-
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
@@ -298,14 +327,95 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
         return sample

+    def set_timesteps(
+        self,
+        num_inference_steps: int,
+        device: Union[str, torch.device] = None,
+        strength: int = 1.0,
+    ):
+        """
+        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+        Args:
+            num_inference_steps (`int`):
+                The number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            original_inference_steps (`int`, *optional*):
+                The original number of inference steps, which will be used to generate a linearly-spaced timestep
+                schedule (which is different from the standard `diffusers` implementation). We will then take
+                `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
+                our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
+        """
+        original_inference_steps = self.original_inference_steps
+        if num_inference_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+        original_steps = (
+            original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
+        )
+
+        if original_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
+        if num_inference_steps > original_steps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
+                f" {original_steps} because the final timestep schedule will be a subset of the"
+                f" `original_inference_steps`-sized initial timestep schedule."
+            )
+
+        # LCM Timesteps Setting
+        # The skipping step parameter k from the paper.
+        k = self.config.num_train_timesteps // original_steps
+        # LCM Training/Distillation Steps Schedule
+        # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
+        lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
+
+        skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+
+        if skipping_step < 1:
+            raise ValueError(
+                f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
+            )
+
+        # LCM Inference Steps Schedule
+        lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
+        # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
+        inference_indices = np.linspace(0, len(lcm_origin_timesteps) - 1, num=num_inference_steps)
+        inference_indices = np.floor(inference_indices).astype(np.int64)
+        timesteps = lcm_origin_timesteps[inference_indices]
+
+        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
+
+        self._step_index = None
+
+    def get_scalings_for_boundary_condition_discrete(self, timestep):
+        self.sigma_data = 0.5  # Default: 0.5
+        scaled_timestep = timestep * self.config.timestep_scaling
+
+        c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
+        return c_skip, c_out
+
     def step(
         self,
         model_output: torch.FloatTensor,
         timestep: int,
         sample: torch.FloatTensor,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-    ) -> Union[DDPMSchedulerOutput, Tuple]:
+    ) -> Union[LCMSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -320,84 +430,81 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
             generator (`torch.Generator`, *optional*):
                 A random number generator.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
+                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.

         Returns:
-            [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
+            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
         """
-        t = timestep
-
-        prev_t = self.previous_timestep(t)
-
-        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
-            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
-        else:
-            predicted_variance = None
-
-        # 1. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        # 1. get previous step value
+        prev_step_index = self.step_index + 1
+        if prev_step_index < len(self.timesteps):
+            prev_timestep = self.timesteps[prev_step_index]
+        else:
+            prev_timestep = timestep
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
-        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
-        current_beta_t = 1 - current_alpha_t

-        # 2. compute predicted original sample from predicted noise also called
-        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if self.config.prediction_type == "epsilon":
-            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        elif self.config.prediction_type == "sample":
-            pred_original_sample = model_output
-        elif self.config.prediction_type == "v_prediction":
-            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+        # 3. Get scalings for boundary conditions
+        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+        # 4. Compute the predicted original sample x_0 based on the model parameterization
+        if self.config.prediction_type == "epsilon":  # noise-prediction
+            predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+        elif self.config.prediction_type == "sample":  # x-prediction
+            predicted_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":  # v-prediction
+            predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
         else:
             raise ValueError(
                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
-                " `v_prediction` for the DDPMScheduler."
+                " `v_prediction` for `LCMScheduler`."
             )

-        # 3. Clip or threshold "predicted x_0"
+        # 5. Clip or threshold "predicted x_0"
         if self.config.thresholding:
-            pred_original_sample = self._threshold_sample(pred_original_sample)
+            predicted_original_sample = self._threshold_sample(predicted_original_sample)
         elif self.config.clip_sample:
-            pred_original_sample = pred_original_sample.clamp(
+            predicted_original_sample = predicted_original_sample.clamp(
                 -self.config.clip_sample_range, self.config.clip_sample_range
             )

-        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
-        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
-        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
-
-        # 5. Compute predicted previous sample µ_t
-        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
-
-        # 6. Add noise
-        variance = 0
-        if t > 0:
-            device = model_output.device
-            variance_noise = randn_tensor(
-                model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+        # 6. Denoise model output using boundary conditions
+        denoised = c_out * predicted_original_sample + c_skip * sample
+
+        # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
+        # Noise is not used on the final timestep of the timestep schedule.
+        # This also means that noise is not used for one-step sampling.
+        if self.step_index != self.num_inference_steps - 1:
+            noise = randn_tensor(
+                model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
             )
-            if self.variance_type == "fixed_small_log":
-                variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
-            elif self.variance_type == "learned_range":
-                variance = self._get_variance(t, predicted_variance=predicted_variance)
-                variance = torch.exp(0.5 * variance) * variance_noise
-            else:
-                variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
-
-        pred_prev_sample = pred_prev_sample + variance
+            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+        else:
+            prev_sample = denoised
+
+        # upon completion increase step index by one
+        self._step_index += 1

         if not return_dict:
-            return (pred_prev_sample,)
+            return (prev_sample, denoised)

-        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -421,6 +528,7 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
@@ -442,19 +550,4 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
         return velocity

     def __len__(self):
         return self.config.num_train_timesteps
-
-    def previous_timestep(self, timestep):
-        if self.custom_timesteps:
-            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
-            if index == self.timesteps.shape[0] - 1:
-                prev_t = torch.tensor(-1)
-            else:
-                prev_t = self.timesteps[index + 1]
-        else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
-        return prev_t
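Putting the pieces together, the scheduler is used like any other diffusers-style scheduler: `set_timesteps` builds the evenly spaced schedule and `step` applies the consistency-model update `denoised = c_out * x_0_pred + c_skip * sample`, re-noising to the next timestep except on the last step. A minimal usage sketch with dummy tensors standing in for the UNet prediction (shapes are illustrative):

```python
import torch
from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler

scheduler = CustomLCMScheduler()           # defaults: 1000 train steps, scaled_linear betas
scheduler.set_timesteps(8, device="cpu")   # 8 evenly spaced timesteps from the 50-step origin schedule
print(scheduler.timesteps)                 # descending, spread across the full training range

sample = torch.randn(1, 4, 64, 64)               # stand-in for a latent
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)      # stand-in for the UNet epsilon prediction
    out = scheduler.step(model_output, t, sample)
    sample = out.prev_sample                     # out.denoised is the consistency-model x_0 estimate
```

Note that `get_scalings_for_boundary_condition_discrete` makes `c_skip` approach 1 and `c_out` approach 0 as the scaled timestep goes to zero, so the update passes the sample through nearly unchanged at the boundary, as the consistency-model parameterization requires.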

View File

@@ -836,8 +836,9 @@ class StableDiffusion:
         bleed_latents: torch.FloatTensor = None,
         **kwargs,
     ):
+        timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]

-        for timestep in tqdm(self.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
+        for timestep in tqdm(timesteps_to_run, leave=False):
             noise_pred = self.predict_noise(
                 latents,
                 text_embeddings,