Mirror of https://github.com/ostris/ai-toolkit.git
Various bug fixes, wip stuff, and tweaks
@@ -80,6 +80,12 @@ class SDTrainer(BaseSDTrainProcess):
         prior_mask_multiplier = None
         target_mask_multiplier = None
 
+        if self.train_config.match_noise_norm:
+            # match the norm of the noise
+            noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
+            noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
+            noise_pred = noise_pred * (noise_norm / noise_pred_norm)
+
         if self.train_config.inverted_mask_prior:
             # we need to make the noise prediction be a masked blending of noise and prior_pred
             prior_mask_multiplier = 1.0 - mask_multiplier
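For reference, the norm-matching step above can be read as a standalone transform. A minimal sketch (the function name and tensor shapes are illustrative, not part of the commit):

    import torch

    def match_noise_norm(noise_pred: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        # L2 norm over the channel/spatial dims, kept broadcastable with keepdim=True
        noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
        pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
        # rescale the prediction so each sample's norm matches its target noise
        return noise_pred * (noise_norm / pred_norm)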
@@ -280,10 +286,10 @@ class SDTrainer(BaseSDTrainProcess):
             adapter_strength_max = 1.0
         else:
             # training with assistance, we want it low
-            # adapter_strength_min = 0.5
-            # adapter_strength_max = 0.8
-            adapter_strength_min = 0.9
-            adapter_strength_max = 1.1
+            adapter_strength_min = 0.5
+            adapter_strength_max = 0.8
+            # adapter_strength_min = 0.9
+            # adapter_strength_max = 1.1
 
         adapter_conditioning_scale = torch.rand(
             (1,), device=self.device_torch, dtype=dtype
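The hunk ends mid-call; presumably the uniform sample is then mapped into [adapter_strength_min, adapter_strength_max), roughly like the sketch below (the mapping line is an assumption, not shown in the hunk):

    import torch

    adapter_strength_min, adapter_strength_max = 0.5, 0.8  # values from the new branch
    u = torch.rand((1,))  # device/dtype arguments omitted in this sketch
    adapter_conditioning_scale = u * (adapter_strength_max - adapter_strength_min) + adapter_strength_min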
@@ -3,6 +3,8 @@ import time
 from typing import List, Optional, Literal, Union
 import random
 
+import torch
+
 from toolkit.prompt_utils import PromptEmbeds
 
 ImgExt = Literal['jpg', 'png', 'webp']
@@ -184,6 +186,11 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
 
+        # match the norm of the noise before computing loss. This will help the model maintain its
+        # current understandin of the brightness of images.
+
+        self.match_noise_norm = kwargs.get('match_noise_norm', False)
+
         # set to -1 to accumulate gradients for entire epoch
         # warning, only do this with a small dataset or you will run out of memory
         self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
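A hypothetical usage of the new option through TrainConfig kwargs, assuming TrainConfig accepts these as keyword arguments as the kwargs.get calls suggest (the dict below is illustrative, not from the commit):

    train_kwargs = {
        'noise_multiplier': 1.0,
        'match_noise_norm': True,           # enable the new norm-matching step
        'gradient_accumulation_steps': -1,  # -1 accumulates for the entire epoch
    }
    train_config = TrainConfig(**train_kwargs)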
@@ -406,6 +413,8 @@ class GenerateImageConfig:
             add_prompt_file: bool = False,  # add a prompt file with generated image
             adapter_image_path: str = None,  # path to adapter image
             adapter_conditioning_scale: float = 1.0,  # scale for adapter conditioning
+            latents: Union[torch.Tensor | None] = None,  # input latent to start with,
+            extra_kwargs: dict = None,  # extra data to save with prompt file
     ):
         self.width: int = width
         self.height: int = height
@@ -416,6 +425,7 @@ class GenerateImageConfig:
         self.prompt_2: str = prompt_2
         self.negative_prompt: str = negative_prompt
         self.negative_prompt_2: str = negative_prompt_2
+        self.latents: Union[torch.Tensor | None] = latents
 
         self.output_path: str = output_path
         self.seed: int = seed
@@ -430,6 +440,7 @@ class GenerateImageConfig:
         self.gen_time: int = int(time.time() * 1000)
         self.adapter_image_path: str = adapter_image_path
         self.adapter_conditioning_scale: float = adapter_conditioning_scale
+        self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
 
         # prompt string will override any settings above
         self._process_prompt_string()
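These extra_kwargs are later merged into the sampling call (see the `class StableDiffusion` hunks below). A hypothetical example passing a pipeline-specific argument through the config (the `eta` key and the minimal constructor call are illustrative only):

    gen_config = GenerateImageConfig(
        adapter_conditioning_scale=1.0,
        extra_kwargs={'eta': 0.0},  # forwarded into the pipeline call via **extra
    )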
@@ -14,7 +14,7 @@ from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 
 from toolkit.basic import flush, value_map
-from toolkit.buckets import get_bucket_for_image_size
+from toolkit.buckets import get_bucket_for_image_size, get_resolution
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.prompt_utils import inject_trigger_into_prompt
 from torchvision import transforms
@@ -718,7 +718,17 @@ class PoiFileItemDTOMixin:
     def setup_poi_bucket(self: 'FileItemDTO'):
         # we are using poi, so we need to calculate the bucket based on the poi
 
-        resolution = self.dataset_config.resolution
+        # TODO this will allow poi to be smaller than resolution. Could affect training image size
+        poi_resolution = min(
+            self.dataset_config.resolution,
+            get_resolution(
+                self.poi_width * self.dataset_config.scale,
+                self.poi_height * self.dataset_config.scale
+            )
+        )
+
+        resolution = min(self.dataset_config.resolution, poi_resolution)
+
         bucket_tolerance = self.dataset_config.bucket_tolerance
         initial_width = int(self.width * self.dataset_config.scale)
         initial_height = int(self.height * self.dataset_config.scale)
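`get_resolution` comes from `toolkit.buckets` (see the import hunk above); the way it feeds `min()` implies it collapses a width/height pair into a single value comparable with `dataset_config.resolution`. A plausible sketch of that contract, purely an assumption about the helper's behavior:

    import math

    # Assumed behavior only: one resolution number for a (width, height) pair,
    # e.g. the side length of the equal-area square. Not the toolkit's actual code.
    def get_resolution(width: float, height: float) -> int:
        return int(math.sqrt(width * height))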
@@ -727,12 +737,30 @@ class PoiFileItemDTOMixin:
         poi_width = int(self.poi_width * self.dataset_config.scale)
         poi_height = int(self.poi_height * self.dataset_config.scale)
 
-        # todo handle a poi that is smaller than resolution
         # determine new cropping
-        crop_left = random.randint(0, poi_x)
-        crop_right = random.randint(poi_x + poi_width, initial_width)
-        crop_top = random.randint(0, poi_y)
-        crop_bottom = random.randint(poi_y + poi_height, initial_height)
+        # crop left
+        if poi_x > 0:
+            crop_left = random.randint(0, poi_x)
+        else:
+            crop_left = 0
+
+        # crop right
+        cr_min = poi_x + poi_width
+        if cr_min < initial_width:
+            crop_right = random.randint(poi_x + poi_width, initial_width)
+        else:
+            crop_right = initial_width
+
+        if poi_y > 0:
+            crop_top = random.randint(0, poi_y)
+        else:
+            crop_top = 0
+
+        if poi_y + poi_height < initial_height:
+            crop_bottom = random.randint(poi_y + poi_height, initial_height)
+        else:
+            crop_bottom = initial_height
+
         new_width = crop_right - crop_left
         new_height = crop_bottom - crop_top
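The new branches guard `random.randint(a, b)`, which raises ValueError when a > b, i.e. when the point-of-interest box touches or overruns an image edge. A standalone sketch of the clamped bounds (the function name is illustrative):

    import random

    def poi_crop_bounds(poi_x, poi_y, poi_width, poi_height, initial_width, initial_height):
        # keep the poi box fully inside the crop; clamp at the frame edges
        crop_left = random.randint(0, poi_x) if poi_x > 0 else 0
        crop_right = (random.randint(poi_x + poi_width, initial_width)
                      if poi_x + poi_width < initial_width else initial_width)
        crop_top = random.randint(0, poi_y) if poi_y > 0 else 0
        crop_bottom = (random.randint(poi_y + poi_height, initial_height)
                       if poi_y + poi_height < initial_height else initial_height)
        return crop_left, crop_top, crop_right, crop_bottom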
toolkit/inversion_utils.py (new file, 410 lines)
@@ -0,0 +1,410 @@
+# ref https://huggingface.co/spaces/editing-images/ledits/blob/main/inversion_utils.py
+
+import torch
+import os
+from tqdm import tqdm
+
+from toolkit import train_tools
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def mu_tilde(model, xt, x0, timestep):
+    "mu_tilde(x_t, x_0) DDPM paper eq. 7"
+    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+    alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+        prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+    alpha_t = model.scheduler.alphas[timestep]
+    beta_t = 1 - alpha_t
+    alpha_bar = model.scheduler.alphas_cumprod[timestep]
+    return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + (
+            (alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
+
+
+def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50):
+    """
+    Samples from P(x_1:T|x_0)
+    """
+    # torch.manual_seed(43256465436)
+    alpha_bar = sd.noise_scheduler.alphas_cumprod
+    sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
+    alphas = sd.noise_scheduler.alphas
+    betas = 1 - alphas
+    # variance_noise_shape = (
+    #     num_inference_steps,
+    #     sd.unet.in_channels,
+    #     sd.unet.sample_size,
+    #     sd.unet.sample_size)
+    variance_noise_shape = list(sample.shape)
+    variance_noise_shape[0] = num_inference_steps
+
+    timesteps = sd.noise_scheduler.timesteps.to(sd.device)
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+    xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16)
+    for t in reversed(timesteps):
+        idx = t_to_idx[int(t)]
+        xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t]
+    xts = torch.cat([xts, sample], dim=0)
+
+    return xts
+
+
+def encode_text(model, prompts):
+    text_input = model.tokenizer(
+        prompts,
+        padding="max_length",
+        max_length=model.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    with torch.no_grad():
+        text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
+    return text_encoding
+
+
+def forward_step(sd: StableDiffusion, model_output, timestep, sample):
+    next_timestep = min(
+        sd.noise_scheduler.config['num_train_timesteps'] - 2,
+        timestep + sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+    )
+
+    # 2. compute alphas, betas
+    alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
+    # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod
+
+    beta_prod_t = 1 - alpha_prod_t
+
+    # 3. compute predicted original sample from predicted noise also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+    # 5. TODO: simple noising implementation
+    next_sample = sd.noise_scheduler.add_noise(
+        pred_original_sample,
+        model_output,
+        torch.LongTensor([next_timestep]))
+    return next_sample
+
+
+def get_variance(sd: StableDiffusion, timestep):  # , prev_timestep):
+    prev_timestep = timestep - sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+    alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
+    alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
+        prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
+    beta_prod_t = 1 - alpha_prod_t
+    beta_prod_t_prev = 1 - alpha_prod_t_prev
+    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+    return variance
+
+
+def get_time_ids_from_latents(sd: StableDiffusion, latents: torch.Tensor):
+    VAE_SCALE_FACTOR = 2 ** (len(sd.vae.config['block_out_channels']) - 1)
+    if sd.is_xl:
+        bs, ch, h, w = list(latents.shape)
+
+        height = h * VAE_SCALE_FACTOR
+        width = w * VAE_SCALE_FACTOR
+
+        dtype = latents.dtype
+        # just do it without any cropping nonsense
+        target_size = (height, width)
+        original_size = (height, width)
+        crops_coords_top_left = (0, 0)
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        add_time_ids = torch.tensor([add_time_ids])
+        add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
+
+        batch_time_ids = torch.cat(
+            [add_time_ids for _ in range(bs)]
+        )
+        return batch_time_ids
+    else:
+        return None
+
+
+def inversion_forward_process(
+        sd: StableDiffusion,
+        sample: torch.Tensor,
+        conditional_embeddings: PromptEmbeds,
+        unconditional_embeddings: PromptEmbeds,
+        etas=None,
+        prog_bar=False,
+        cfg_scale=3.5,
+        num_inference_steps=50, eps=None
+):
+    current_num_timesteps = len(sd.noise_scheduler.timesteps)
+    sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device)
+
+    timesteps = sd.noise_scheduler.timesteps.to(sd.device)
+    # variance_noise_shape = (
+    #     num_inference_steps,
+    #     sd.unet.in_channels,
+    #     sd.unet.sample_size,
+    #     sd.unet.sample_size
+    # )
+    variance_noise_shape = list(sample.shape)
+    variance_noise_shape[0] = num_inference_steps
+    if etas is None or (type(etas) in [int, float] and etas == 0):
+        eta_is_zero = True
+        zs = None
+    else:
+        eta_is_zero = False
+        if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps
+        xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps)
+        alpha_bar = sd.noise_scheduler.alphas_cumprod
+        zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16)
+
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+    noisy_sample = sample
+    op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
+
+    for timestep in op:
+        idx = t_to_idx[int(timestep)]
+        # 1. predict noise residual
+        if not eta_is_zero:
+            noisy_sample = xts[idx][None]
+
+        added_cond_kwargs = {}
+
+        with torch.no_grad():
+            text_embeddings = train_tools.concat_prompt_embeddings(
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+                1,  # batch size
+            )
+            if sd.is_xl:
+                add_time_ids = get_time_ids_from_latents(sd, noisy_sample)
+                # add extra for cfg
+                add_time_ids = torch.cat(
+                    [add_time_ids] * 2, dim=0
+                )
+
+                added_cond_kwargs = {
+                    "text_embeds": text_embeddings.pooled_embeds,
+                    "time_ids": add_time_ids,
+                }
+
+            # double up for cfg
+            latent_model_input = torch.cat(
+                [noisy_sample] * 2, dim=0
+            )
+
+            noise_pred = sd.unet(
+                latent_model_input,
+                timestep,
+                encoder_hidden_states=text_embeddings.text_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
+
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+
+            # out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding)
+            # cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings)
+
+        noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
+
+        if eta_is_zero:
+            # 2. compute more noisy image and set x_t -> x_t+1
+            noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample)
+            xts = None
+
+        else:
+            xtm1 = xts[idx + 1][None]
+            # pred of x0
+            pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[
+                timestep] ** 0.5
+
+            # direction to xt
+            prev_timestep = timestep - sd.noise_scheduler.config[
+                'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+            alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
+                prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
+
+            variance = get_variance(sd, timestep)
+            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
+
+            mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+            zs[idx] = z
+
+            # correction to avoid error accumulation
+            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+            xts[idx + 1] = xtm1
+
+    if not zs is None:
+        zs[-1] = torch.zeros_like(zs[-1])
+
+    # restore timesteps
+    sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device)
+
+    return noisy_sample, zs, xts
+
+
+#
+# def inversion_forward_process(
+#         model,
+#         sample,
+#         etas=None,
+#         prog_bar=False,
+#         prompt="",
+#         cfg_scale=3.5,
+#         num_inference_steps=50, eps=None
+# ):
+#     if not prompt == "":
+#         text_embeddings = encode_text(model, prompt)
+#     uncond_embedding = encode_text(model, "")
+#     timesteps = model.scheduler.timesteps.to(model.device)
+#     variance_noise_shape = (
+#         num_inference_steps,
+#         model.unet.in_channels,
+#         model.unet.sample_size,
+#         model.unet.sample_size)
+#     if etas is None or (type(etas) in [int, float] and etas == 0):
+#         eta_is_zero = True
+#         zs = None
+#     else:
+#         eta_is_zero = False
+#         if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+#         xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps)
+#         alpha_bar = model.scheduler.alphas_cumprod
+#         zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype=torch.float16)
+#
+#     t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+#     noisy_sample = sample
+#     op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
+#
+#     for t in op:
+#         idx = t_to_idx[int(t)]
+#         # 1. predict noise residual
+#         if not eta_is_zero:
+#             noisy_sample = xts[idx][None]
+#
+#         with torch.no_grad():
+#             out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=uncond_embedding)
+#             if not prompt == "":
+#                 cond_out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=text_embeddings)
+#
+#         if not prompt == "":
+#             ## classifier free guidance
+#             noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
+#         else:
+#             noise_pred = out.sample
+#
+#         if eta_is_zero:
+#             # 2. compute more noisy image and set x_t -> x_t+1
+#             noisy_sample = forward_step(model, noise_pred, t, noisy_sample)
+#
+#         else:
+#             xtm1 = xts[idx + 1][None]
+#             # pred of x0
+#             pred_original_sample = (noisy_sample - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
+#
+#             # direction to xt
+#             prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+#             alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+#                 prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+#
+#             variance = get_variance(model, t)
+#             pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
+#
+#             mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+#
+#             z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+#             zs[idx] = z
+#
+#             # correction to avoid error accumulation
+#             xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+#             xts[idx + 1] = xtm1
+#
+#     if not zs is None:
+#         zs[-1] = torch.zeros_like(zs[-1])
+#
+#     return noisy_sample, zs, xts
+
+
+def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
+    # 1. get previous step value (=t-1)
+    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+    # 2. compute alphas, betas
+    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+    alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+        prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+    beta_prod_t = 1 - alpha_prod_t
+    # 3. compute predicted original sample from predicted noise also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+    # 5. compute variance: "sigma_t(η)" -> see formula (16)
+    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+    # variance = self.scheduler._get_variance(timestep, prev_timestep)
+    variance = get_variance(model, timestep)  # , prev_timestep)
+    std_dev_t = eta * variance ** (0.5)
+    # Take care of asymetric reverse process (asyrp)
+    model_output_direction = model_output
+    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
+    pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
+    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+    # 8. Add noice if eta > 0
+    if eta > 0:
+        if variance_noise is None:
+            variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16)
+        sigma_z = eta * variance ** (0.5) * variance_noise
+        prev_sample = prev_sample + sigma_z
+
+    return prev_sample
+
+
+def inversion_reverse_process(
+        model,
+        xT,
+        etas=0,
+        prompts="",
+        cfg_scales=None,
+        prog_bar=False,
+        zs=None,
+        controller=None,
+        asyrp=False):
+    batch_size = len(prompts)
+
+    cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device, dtype=torch.float16)
+
+    text_embeddings = encode_text(model, prompts)
+    uncond_embedding = encode_text(model, [""] * batch_size)
+
+    if etas is None: etas = 0
+    if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+    assert len(etas) == model.scheduler.num_inference_steps
+    timesteps = model.scheduler.timesteps.to(model.device)
+
+    xt = xT.expand(batch_size, -1, -1, -1)
+    op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
+
+    for t in op:
+        idx = t_to_idx[int(t)]
+        ## Unconditional embedding
+        with torch.no_grad():
+            uncond_out = model.unet.forward(xt, timestep=t,
+                                            encoder_hidden_states=uncond_embedding)
+
+        ## Conditional embedding
+        if prompts:
+            with torch.no_grad():
+                cond_out = model.unet.forward(xt, timestep=t,
+                                              encoder_hidden_states=text_embeddings)
+
+        z = zs[idx] if not zs is None else None
+        z = z.expand(batch_size, -1, -1, -1)
+        if prompts:
+            ## classifier free guidance
+            noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+        else:
+            noise_pred = uncond_out.sample
+        # 2. compute less noisy image and set x_t -> x_t-1
+        xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
+        if controller is not None:
+            xt = controller.step_callback(xt)
+    return xt, zs
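Hypothetical usage of the new inversion utilities (variable names are illustrative; `sd` is an initialized toolkit StableDiffusion wrapper, `latent` a VAE-encoded image latent, and the PromptEmbeds objects come from the model's prompt encoder):

    from toolkit.inversion_utils import inversion_forward_process

    xT, zs, xts = inversion_forward_process(
        sd,
        latent,
        conditional_embeddings=cond_embeds,      # PromptEmbeds for the prompt
        unconditional_embeddings=uncond_embeds,  # PromptEmbeds for ""
        etas=1.0,             # eta > 0 records per-step noise maps in zs
        prog_bar=True,
        cfg_scale=3.5,
        num_inference_steps=50,
    )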
@@ -424,6 +424,10 @@ class StableDiffusion:
 
             if sampler.startswith("sample_"):
                 extra['use_karras_sigmas'] = True
+            extra = {
+                **extra,
+                **gen_config.extra_kwargs,
+            }
 
             img = pipeline(
                 # prompt=gen_config.prompt,
@@ -439,6 +443,7 @@ class StableDiffusion:
                 num_inference_steps=gen_config.num_inference_steps,
                 guidance_scale=gen_config.guidance_scale,
                 guidance_rescale=grs,
+                latents=gen_config.latents,
                 **extra
             ).images[0]
         else:
@@ -451,6 +456,7 @@ class StableDiffusion:
                 width=gen_config.width,
                 num_inference_steps=gen_config.num_inference_steps,
                 guidance_scale=gen_config.guidance_scale,
+                latents=gen_config.latents,
                 **extra
             ).images[0]
 