mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Various bug fixes, wip stuff, and tweaks
This commit is contained in:
@@ -80,6 +80,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prior_mask_multiplier = None
|
||||
target_mask_multiplier = None
|
||||
|
||||
if self.train_config.match_noise_norm:
|
||||
# match the norm of the noise
|
||||
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
||||
|
||||
if self.train_config.inverted_mask_prior:
|
||||
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
||||
prior_mask_multiplier = 1.0 - mask_multiplier
|
||||
@@ -280,10 +286,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
adapter_strength_max = 1.0
|
||||
else:
|
||||
# training with assistance, we want it low
|
||||
# adapter_strength_min = 0.5
|
||||
# adapter_strength_max = 0.8
|
||||
adapter_strength_min = 0.9
|
||||
adapter_strength_max = 1.1
|
||||
adapter_strength_min = 0.5
|
||||
adapter_strength_max = 0.8
|
||||
# adapter_strength_min = 0.9
|
||||
# adapter_strength_max = 1.1
|
||||
|
||||
adapter_conditioning_scale = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
|
||||
@@ -3,6 +3,8 @@ import time
|
||||
from typing import List, Optional, Literal, Union
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
|
||||
ImgExt = Literal['jpg', 'png', 'webp']
|
||||
@@ -184,6 +186,11 @@ class TrainConfig:
|
||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
|
||||
# match the norm of the noise before computing loss. This will help the model maintain its
|
||||
#current understandin of the brightness of images.
|
||||
|
||||
self.match_noise_norm = kwargs.get('match_noise_norm', False)
|
||||
|
||||
# set to -1 to accumulate gradients for entire epoch
|
||||
# warning, only do this with a small dataset or you will run out of memory
|
||||
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
|
||||
@@ -406,6 +413,8 @@ class GenerateImageConfig:
|
||||
add_prompt_file: bool = False, # add a prompt file with generated image
|
||||
adapter_image_path: str = None, # path to adapter image
|
||||
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
|
||||
latents: Union[torch.Tensor | None] = None, # input latent to start with,
|
||||
extra_kwargs: dict = None, # extra data to save with prompt file
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -416,6 +425,7 @@ class GenerateImageConfig:
|
||||
self.prompt_2: str = prompt_2
|
||||
self.negative_prompt: str = negative_prompt
|
||||
self.negative_prompt_2: str = negative_prompt_2
|
||||
self.latents: Union[torch.Tensor | None] = latents
|
||||
|
||||
self.output_path: str = output_path
|
||||
self.seed: int = seed
|
||||
@@ -430,6 +440,7 @@ class GenerateImageConfig:
|
||||
self.gen_time: int = int(time.time() * 1000)
|
||||
self.adapter_image_path: str = adapter_image_path
|
||||
self.adapter_conditioning_scale: float = adapter_conditioning_scale
|
||||
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
|
||||
|
||||
# prompt string will override any settings above
|
||||
self._process_prompt_string()
|
||||
|
||||
@@ -14,7 +14,7 @@ from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.basic import flush, value_map
|
||||
from toolkit.buckets import get_bucket_for_image_size
|
||||
from toolkit.buckets import get_bucket_for_image_size, get_resolution
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||
from torchvision import transforms
|
||||
@@ -718,7 +718,17 @@ class PoiFileItemDTOMixin:
|
||||
def setup_poi_bucket(self: 'FileItemDTO'):
|
||||
# we are using poi, so we need to calculate the bucket based on the poi
|
||||
|
||||
resolution = self.dataset_config.resolution
|
||||
# TODO this will allow poi to be smaller than resolution. Could affect training image size
|
||||
poi_resolution = min(
|
||||
self.dataset_config.resolution,
|
||||
get_resolution(
|
||||
self.poi_width * self.dataset_config.scale,
|
||||
self.poi_height * self.dataset_config.scale
|
||||
)
|
||||
)
|
||||
|
||||
resolution = min(self.dataset_config.resolution, poi_resolution)
|
||||
|
||||
bucket_tolerance = self.dataset_config.bucket_tolerance
|
||||
initial_width = int(self.width * self.dataset_config.scale)
|
||||
initial_height = int(self.height * self.dataset_config.scale)
|
||||
@@ -727,12 +737,30 @@ class PoiFileItemDTOMixin:
|
||||
poi_width = int(self.poi_width * self.dataset_config.scale)
|
||||
poi_height = int(self.poi_height * self.dataset_config.scale)
|
||||
|
||||
# todo handle a poi that is smaller than resolution
|
||||
# determine new cropping
|
||||
crop_left = random.randint(0, poi_x)
|
||||
crop_right = random.randint(poi_x + poi_width, initial_width)
|
||||
crop_top = random.randint(0, poi_y)
|
||||
crop_bottom = random.randint(poi_y + poi_height, initial_height)
|
||||
|
||||
# crop left
|
||||
if poi_x > 0:
|
||||
crop_left = random.randint(0, poi_x)
|
||||
else:
|
||||
crop_left = 0
|
||||
|
||||
# crop right
|
||||
cr_min = poi_x + poi_width
|
||||
if cr_min < initial_width:
|
||||
crop_right = random.randint(poi_x + poi_width, initial_width)
|
||||
else:
|
||||
crop_right = initial_width
|
||||
|
||||
if poi_y > 0:
|
||||
crop_top = random.randint(0, poi_y)
|
||||
else:
|
||||
crop_top = 0
|
||||
|
||||
if poi_y + poi_height < initial_height:
|
||||
crop_bottom = random.randint(poi_y + poi_height, initial_height)
|
||||
else:
|
||||
crop_bottom = initial_height
|
||||
|
||||
new_width = crop_right - crop_left
|
||||
new_height = crop_bottom - crop_top
|
||||
|
||||
410
toolkit/inversion_utils.py
Normal file
410
toolkit/inversion_utils.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# ref https://huggingface.co/spaces/editing-images/ledits/blob/main/inversion_utils.py
|
||||
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
def mu_tilde(model, xt, x0, timestep):
|
||||
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
|
||||
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
||||
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
||||
alpha_t = model.scheduler.alphas[timestep]
|
||||
beta_t = 1 - alpha_t
|
||||
alpha_bar = model.scheduler.alphas_cumprod[timestep]
|
||||
return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + (
|
||||
(alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
|
||||
|
||||
|
||||
def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50):
|
||||
"""
|
||||
Samples from P(x_1:T|x_0)
|
||||
"""
|
||||
# torch.manual_seed(43256465436)
|
||||
alpha_bar = sd.noise_scheduler.alphas_cumprod
|
||||
sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
|
||||
alphas = sd.noise_scheduler.alphas
|
||||
betas = 1 - alphas
|
||||
# variance_noise_shape = (
|
||||
# num_inference_steps,
|
||||
# sd.unet.in_channels,
|
||||
# sd.unet.sample_size,
|
||||
# sd.unet.sample_size)
|
||||
variance_noise_shape = list(sample.shape)
|
||||
variance_noise_shape[0] = num_inference_steps
|
||||
|
||||
timesteps = sd.noise_scheduler.timesteps.to(sd.device)
|
||||
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
||||
xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16)
|
||||
for t in reversed(timesteps):
|
||||
idx = t_to_idx[int(t)]
|
||||
xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t]
|
||||
xts = torch.cat([xts, sample], dim=0)
|
||||
|
||||
return xts
|
||||
|
||||
|
||||
def encode_text(model, prompts):
|
||||
text_input = model.tokenizer(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
max_length=model.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
with torch.no_grad():
|
||||
text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
|
||||
return text_encoding
|
||||
|
||||
|
||||
def forward_step(sd: StableDiffusion, model_output, timestep, sample):
|
||||
next_timestep = min(
|
||||
sd.noise_scheduler.config['num_train_timesteps'] - 2,
|
||||
timestep + sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
|
||||
)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
|
||||
# alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
|
||||
# 5. TODO: simple noising implementation
|
||||
next_sample = sd.noise_scheduler.add_noise(
|
||||
pred_original_sample,
|
||||
model_output,
|
||||
torch.LongTensor([next_timestep]))
|
||||
return next_sample
|
||||
|
||||
|
||||
def get_variance(sd: StableDiffusion, timestep): # , prev_timestep):
|
||||
prev_timestep = timestep - sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
|
||||
alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
return variance
|
||||
|
||||
|
||||
def get_time_ids_from_latents(sd: StableDiffusion, latents: torch.Tensor):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(sd.vae.config['block_out_channels']) - 1)
|
||||
if sd.is_xl:
|
||||
bs, ch, h, w = list(latents.shape)
|
||||
|
||||
height = h * VAE_SCALE_FACTOR
|
||||
width = w * VAE_SCALE_FACTOR
|
||||
|
||||
dtype = latents.dtype
|
||||
# just do it without any cropping nonsense
|
||||
target_size = (height, width)
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
|
||||
|
||||
batch_time_ids = torch.cat(
|
||||
[add_time_ids for _ in range(bs)]
|
||||
)
|
||||
return batch_time_ids
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def inversion_forward_process(
|
||||
sd: StableDiffusion,
|
||||
sample: torch.Tensor,
|
||||
conditional_embeddings: PromptEmbeds,
|
||||
unconditional_embeddings: PromptEmbeds,
|
||||
etas=None,
|
||||
prog_bar=False,
|
||||
cfg_scale=3.5,
|
||||
num_inference_steps=50, eps=None
|
||||
):
|
||||
current_num_timesteps = len(sd.noise_scheduler.timesteps)
|
||||
sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device)
|
||||
|
||||
timesteps = sd.noise_scheduler.timesteps.to(sd.device)
|
||||
# variance_noise_shape = (
|
||||
# num_inference_steps,
|
||||
# sd.unet.in_channels,
|
||||
# sd.unet.sample_size,
|
||||
# sd.unet.sample_size
|
||||
# )
|
||||
variance_noise_shape = list(sample.shape)
|
||||
variance_noise_shape[0] = num_inference_steps
|
||||
if etas is None or (type(etas) in [int, float] and etas == 0):
|
||||
eta_is_zero = True
|
||||
zs = None
|
||||
else:
|
||||
eta_is_zero = False
|
||||
if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps
|
||||
xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps)
|
||||
alpha_bar = sd.noise_scheduler.alphas_cumprod
|
||||
zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16)
|
||||
|
||||
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
||||
noisy_sample = sample
|
||||
op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
|
||||
|
||||
for timestep in op:
|
||||
idx = t_to_idx[int(timestep)]
|
||||
# 1. predict noise residual
|
||||
if not eta_is_zero:
|
||||
noisy_sample = xts[idx][None]
|
||||
|
||||
added_cond_kwargs = {}
|
||||
|
||||
with torch.no_grad():
|
||||
text_embeddings = train_tools.concat_prompt_embeddings(
|
||||
unconditional_embeddings, # negative embedding
|
||||
conditional_embeddings, # positive embedding
|
||||
1, # batch size
|
||||
)
|
||||
if sd.is_xl:
|
||||
add_time_ids = get_time_ids_from_latents(sd, noisy_sample)
|
||||
# add extra for cfg
|
||||
add_time_ids = torch.cat(
|
||||
[add_time_ids] * 2, dim=0
|
||||
)
|
||||
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": text_embeddings.pooled_embeds,
|
||||
"time_ids": add_time_ids,
|
||||
}
|
||||
|
||||
# double up for cfg
|
||||
latent_model_input = torch.cat(
|
||||
[noisy_sample] * 2, dim=0
|
||||
)
|
||||
|
||||
noise_pred = sd.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
# out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding)
|
||||
# cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings)
|
||||
|
||||
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if eta_is_zero:
|
||||
# 2. compute more noisy image and set x_t -> x_t+1
|
||||
noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample)
|
||||
xts = None
|
||||
|
||||
else:
|
||||
xtm1 = xts[idx + 1][None]
|
||||
# pred of x0
|
||||
pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[
|
||||
timestep] ** 0.5
|
||||
|
||||
# direction to xt
|
||||
prev_timestep = timestep - sd.noise_scheduler.config[
|
||||
'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
|
||||
alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
|
||||
|
||||
variance = get_variance(sd, timestep)
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
|
||||
|
||||
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
|
||||
z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
|
||||
zs[idx] = z
|
||||
|
||||
# correction to avoid error accumulation
|
||||
xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
|
||||
xts[idx + 1] = xtm1
|
||||
|
||||
if not zs is None:
|
||||
zs[-1] = torch.zeros_like(zs[-1])
|
||||
|
||||
# restore timesteps
|
||||
sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device)
|
||||
|
||||
return noisy_sample, zs, xts
|
||||
|
||||
|
||||
#
|
||||
# def inversion_forward_process(
|
||||
# model,
|
||||
# sample,
|
||||
# etas=None,
|
||||
# prog_bar=False,
|
||||
# prompt="",
|
||||
# cfg_scale=3.5,
|
||||
# num_inference_steps=50, eps=None
|
||||
# ):
|
||||
# if not prompt == "":
|
||||
# text_embeddings = encode_text(model, prompt)
|
||||
# uncond_embedding = encode_text(model, "")
|
||||
# timesteps = model.scheduler.timesteps.to(model.device)
|
||||
# variance_noise_shape = (
|
||||
# num_inference_steps,
|
||||
# model.unet.in_channels,
|
||||
# model.unet.sample_size,
|
||||
# model.unet.sample_size)
|
||||
# if etas is None or (type(etas) in [int, float] and etas == 0):
|
||||
# eta_is_zero = True
|
||||
# zs = None
|
||||
# else:
|
||||
# eta_is_zero = False
|
||||
# if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
|
||||
# xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps)
|
||||
# alpha_bar = model.scheduler.alphas_cumprod
|
||||
# zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype=torch.float16)
|
||||
#
|
||||
# t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
||||
# noisy_sample = sample
|
||||
# op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
|
||||
#
|
||||
# for t in op:
|
||||
# idx = t_to_idx[int(t)]
|
||||
# # 1. predict noise residual
|
||||
# if not eta_is_zero:
|
||||
# noisy_sample = xts[idx][None]
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=uncond_embedding)
|
||||
# if not prompt == "":
|
||||
# cond_out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=text_embeddings)
|
||||
#
|
||||
# if not prompt == "":
|
||||
# ## classifier free guidance
|
||||
# noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
|
||||
# else:
|
||||
# noise_pred = out.sample
|
||||
#
|
||||
# if eta_is_zero:
|
||||
# # 2. compute more noisy image and set x_t -> x_t+1
|
||||
# noisy_sample = forward_step(model, noise_pred, t, noisy_sample)
|
||||
#
|
||||
# else:
|
||||
# xtm1 = xts[idx + 1][None]
|
||||
# # pred of x0
|
||||
# pred_original_sample = (noisy_sample - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
|
||||
#
|
||||
# # direction to xt
|
||||
# prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
||||
# alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
||||
# prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
||||
#
|
||||
# variance = get_variance(model, t)
|
||||
# pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
|
||||
#
|
||||
# mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
#
|
||||
# z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
|
||||
# zs[idx] = z
|
||||
#
|
||||
# # correction to avoid error accumulation
|
||||
# xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
|
||||
# xts[idx + 1] = xtm1
|
||||
#
|
||||
# if not zs is None:
|
||||
# zs[-1] = torch.zeros_like(zs[-1])
|
||||
#
|
||||
# return noisy_sample, zs, xts
|
||||
|
||||
|
||||
def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
# variance = self.scheduler._get_variance(timestep, prev_timestep)
|
||||
variance = get_variance(model, timestep) # , prev_timestep)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
# Take care of asymetric reverse process (asyrp)
|
||||
model_output_direction = model_output
|
||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
# pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
|
||||
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
# 8. Add noice if eta > 0
|
||||
if eta > 0:
|
||||
if variance_noise is None:
|
||||
variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16)
|
||||
sigma_z = eta * variance ** (0.5) * variance_noise
|
||||
prev_sample = prev_sample + sigma_z
|
||||
|
||||
return prev_sample
|
||||
|
||||
|
||||
def inversion_reverse_process(
|
||||
model,
|
||||
xT,
|
||||
etas=0,
|
||||
prompts="",
|
||||
cfg_scales=None,
|
||||
prog_bar=False,
|
||||
zs=None,
|
||||
controller=None,
|
||||
asyrp=False):
|
||||
batch_size = len(prompts)
|
||||
|
||||
cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device, dtype=torch.float16)
|
||||
|
||||
text_embeddings = encode_text(model, prompts)
|
||||
uncond_embedding = encode_text(model, [""] * batch_size)
|
||||
|
||||
if etas is None: etas = 0
|
||||
if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
|
||||
assert len(etas) == model.scheduler.num_inference_steps
|
||||
timesteps = model.scheduler.timesteps.to(model.device)
|
||||
|
||||
xt = xT.expand(batch_size, -1, -1, -1)
|
||||
op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
|
||||
|
||||
t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
|
||||
|
||||
for t in op:
|
||||
idx = t_to_idx[int(t)]
|
||||
## Unconditional embedding
|
||||
with torch.no_grad():
|
||||
uncond_out = model.unet.forward(xt, timestep=t,
|
||||
encoder_hidden_states=uncond_embedding)
|
||||
|
||||
## Conditional embedding
|
||||
if prompts:
|
||||
with torch.no_grad():
|
||||
cond_out = model.unet.forward(xt, timestep=t,
|
||||
encoder_hidden_states=text_embeddings)
|
||||
|
||||
z = zs[idx] if not zs is None else None
|
||||
z = z.expand(batch_size, -1, -1, -1)
|
||||
if prompts:
|
||||
## classifier free guidance
|
||||
noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
|
||||
else:
|
||||
noise_pred = uncond_out.sample
|
||||
# 2. compute less noisy image and set x_t -> x_t-1
|
||||
xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
|
||||
if controller is not None:
|
||||
xt = controller.step_callback(xt)
|
||||
return xt, zs
|
||||
@@ -424,6 +424,10 @@ class StableDiffusion:
|
||||
|
||||
if sampler.startswith("sample_"):
|
||||
extra['use_karras_sigmas'] = True
|
||||
extra = {
|
||||
**extra,
|
||||
**gen_config.extra_kwargs,
|
||||
}
|
||||
|
||||
img = pipeline(
|
||||
# prompt=gen_config.prompt,
|
||||
@@ -439,6 +443,7 @@ class StableDiffusion:
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
guidance_rescale=grs,
|
||||
latents=gen_config.latents,
|
||||
**extra
|
||||
).images[0]
|
||||
else:
|
||||
@@ -451,6 +456,7 @@ class StableDiffusion:
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
**extra
|
||||
).images[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user