Added timestep modifications to lcm scheduler for more evenly spaced timesteps

This commit is contained in:
Jaret Burkett
2023-11-17 23:26:52 -07:00
parent 6280284d8b
commit fbec68681d
4 changed files with 277 additions and 175 deletions

View File

@@ -688,7 +688,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
with self.timer('prepare_noise'):
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
if self.train_config.noise_scheduler == 'lcm':
if self.train_config.noise_scheduler in ['custom_lcm']:
# we store this value on our custom one
self.sd.noise_scheduler.set_timesteps(
self.sd.noise_scheduler.train_timesteps, device=self.device_torch
)
elif self.train_config.noise_scheduler in ['lcm']:
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
)
@@ -727,12 +732,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
min_noise_steps,
max_noise_steps,
(batch_size,),
device=self.device_torch
)
if min_noise_steps == max_noise_steps:
timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
timesteps = torch.randint(
min_noise_steps,
max_noise_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")

View File

@@ -17,7 +17,7 @@ from diffusers import (
from k_diffusion.external import CompVisDenoiser
from toolkit.samplers.scheduling_ddpm import ADDPMScheduler
from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
@@ -78,8 +78,8 @@ def get_sampler(
scheduler_cls = KDPM2AncestralDiscreteScheduler
elif sampler == "lcm":
scheduler_cls = LCMScheduler
elif sampler == "addpm":
scheduler_cls = ADDPMScheduler
elif sampler == "custom_lcm":
scheduler_cls = CustomLCMScheduler
config = copy.deepcopy(sdxl_sampler_config)
config.update(sched_init_args)

View File

@@ -1,4 +1,4 @@
# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
@@ -22,13 +23,16 @@ import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class DDPMSchedulerOutput(BaseOutput):
class LCMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
@@ -42,9 +46,10 @@ class DDPMSchedulerOutput(BaseOutput):
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
denoised: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
@@ -89,12 +94,52 @@ def betas_for_alpha_bar(
return torch.tensor(betas, dtype=torch.float32)
class ADDPMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
"""
`DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
class CustomLCMScheduler(SchedulerMixin, ConfigMixin):
"""
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
non-Markovian guidance.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`, defaults to 1000):
@@ -106,13 +151,23 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
variance_type (`str`, defaults to `"fixed_small"`):
Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
`fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
original_inference_steps (`int`, *optional*, defaults to 50):
The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
clip_sample_range (`float`, defaults to 1.0):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, defaults to `True`):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -127,32 +182,38 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
timestep_scaling (`float`, defaults to 10.0):
The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
`c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
error at the default of `10.0` is already pretty small).
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "scaled_linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
original_inference_steps: int = 50,
clip_sample: bool = False,
clip_sample_range: float = 1.0,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
timestep_scaling: float = 10.0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -166,27 +227,55 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
self.is_training = False
self.original_inference_steps = 50
# setable values
self.custom_timesteps = False
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self.variance_type = variance_type
self.train_timesteps = 1000
self._step_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
@property
def step_index(self):
return self._step_index
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
@@ -198,73 +287,13 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
):
original_steps = 50 if num_inference_steps != 1000 else 1000
train_timesteps = self.config['num_train_timesteps']
strength = 1.0
c = train_timesteps // original_steps
# LCM Training Steps Schedule
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
# LCM Inference Steps Schedule
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
self._step_index = None
self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
prev_t = self.previous_timestep(t)
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
# we always take the log of variance, so clamp it to ensure it's not 0
variance = torch.clamp(variance, min=1e-20)
if variance_type is None:
variance_type = self.config.variance_type
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = variance
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
variance = torch.log(variance)
variance = torch.exp(0.5 * variance)
elif variance_type == "fixed_large":
variance = current_beta_t
elif variance_type == "fixed_large_log":
# Glide max_log
variance = torch.log(current_beta_t)
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
min_log = torch.log(variance)
max_log = torch.log(current_beta_t)
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
return variance
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
@@ -298,14 +327,95 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
strength: int = 1.0,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
original_inference_steps (`int`, *optional*):
The original number of inference steps, which will be used to generate a linearly-spaced timestep
schedule (which is different from the standard `diffusers` implementation). We will then take
`num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
"""
original_inference_steps = self.original_inference_steps
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps
original_steps = (
original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
)
if original_steps > self.config.num_train_timesteps:
raise ValueError(
f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
if num_inference_steps > original_steps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
f" {original_steps} because the final timestep schedule will be a subset of the"
f" `original_inference_steps`-sized initial timestep schedule."
)
# LCM Timesteps Setting
# The skipping step parameter k from the paper.
k = self.config.num_train_timesteps // original_steps
# LCM Training/Distillation Steps Schedule
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
if skipping_step < 1:
raise ValueError(
f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
)
# LCM Inference Steps Schedule
lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
# Select (approximately) evenly spaced indices from lcm_origin_timesteps.
inference_indices = np.linspace(0, len(lcm_origin_timesteps) - 1, num=num_inference_steps)
inference_indices = np.floor(inference_indices).astype(np.int64)
timesteps = lcm_origin_timesteps[inference_indices]
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
self._step_index = None
def get_scalings_for_boundary_condition_discrete(self, timestep):
self.sigma_data = 0.5 # Default: 0.5
scaled_timestep = timestep * self.config.timestep_scaling
c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
) -> Union[LCMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -320,84 +430,81 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
t = timestep
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
prev_t = self.previous_timestep(t)
if self.step_index is None:
self._init_step_index(timestep)
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
# 1. get previous step value
prev_step_index = self.step_index + 1
if prev_step_index < len(self.timesteps):
prev_timestep = self.timesteps[prev_step_index]
else:
predicted_variance = None
prev_timestep = timestep
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
current_beta_t = 1 - current_alpha_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# 3. Get scalings for boundary conditions
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
# 4. Compute the predicted original sample x_0 based on the model parameterization
if self.config.prediction_type == "epsilon": # noise-prediction
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
elif self.config.prediction_type == "sample": # x-prediction
predicted_original_sample = model_output
elif self.config.prediction_type == "v_prediction": # v-prediction
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
" `v_prediction` for the DDPMScheduler."
" `v_prediction` for `LCMScheduler`."
)
# 3. Clip or threshold "predicted x_0"
# 5. Clip or threshold "predicted x_0"
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
predicted_original_sample = self._threshold_sample(predicted_original_sample)
elif self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
predicted_original_sample = predicted_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
# 6. Denoise model output using boundary conditions
denoised = c_out * predicted_original_sample + c_skip * sample
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
device = model_output.device
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
# Noise is not used on the final timestep of the timestep schedule.
# This also means that noise is not used for one-step sampling.
if self.step_index != self.num_inference_steps - 1:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
)
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
elif self.variance_type == "learned_range":
variance = self._get_variance(t, predicted_variance=predicted_variance)
variance = torch.exp(0.5 * variance) * variance_noise
else:
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised
pred_prev_sample = pred_prev_sample + variance
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (pred_prev_sample,)
return (prev_sample, denoised)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
@@ -421,6 +528,7 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
@@ -442,19 +550,4 @@ class ADDPMScheduler(SchedulerMixin, ConfigMixin):
return velocity
def __len__(self):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
if self.custom_timesteps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
return prev_t
return self.config.num_train_timesteps

View File

@@ -836,8 +836,9 @@ class StableDiffusion:
bleed_latents: torch.FloatTensor = None,
**kwargs,
):
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
for timestep in tqdm(self.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
for timestep in tqdm(timesteps_to_run, leave=False):
noise_pred = self.predict_noise(
latents,
text_embeddings,