From ceaf1d94540e44f15f9ffe22880ae87521a57bae Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 2 Nov 2023 18:19:20 -0600
Subject: [PATCH] Various bug fixes, wip stuff, and tweaks

---
 extensions_built_in/sd_trainer/SDTrainer.py |  14 +-
 toolkit/config_modules.py                   |  11 +
 toolkit/dataloader_mixins.py                |  42 +-
 toolkit/inversion_utils.py                  | 410 ++++++++++++++++++++
 toolkit/stable_diffusion_model.py           |   6 +
 5 files changed, 472 insertions(+), 11 deletions(-)
 create mode 100644 toolkit/inversion_utils.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 9e326619..f17747a2 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -80,6 +80,12 @@ class SDTrainer(BaseSDTrainProcess):
         prior_mask_multiplier = None
         target_mask_multiplier = None
 
+        if self.train_config.match_noise_norm:
+            # match the norm of the noise
+            noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
+            noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
+            noise_pred = noise_pred * (noise_norm / noise_pred_norm)
+
         if self.train_config.inverted_mask_prior:
             # we need to make the noise prediction be a masked blending of noise and prior_pred
             prior_mask_multiplier = 1.0 - mask_multiplier
@@ -280,10 +286,10 @@ class SDTrainer(BaseSDTrainProcess):
                 adapter_strength_max = 1.0
             else:
                 # training with assistance, we want it low
-                # adapter_strength_min = 0.5
-                # adapter_strength_max = 0.8
-                adapter_strength_min = 0.9
-                adapter_strength_max = 1.1
+                adapter_strength_min = 0.5
+                adapter_strength_max = 0.8
+                # adapter_strength_min = 0.9
+                # adapter_strength_max = 1.1
 
             adapter_conditioning_scale = torch.rand(
                 (1,), device=self.device_torch, dtype=dtype
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 4406a85b..5e1b716f 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -3,6 +3,8 @@ import time
 from typing import List, Optional, Literal, Union
 import random
 
+import torch
+
 from toolkit.prompt_utils import PromptEmbeds
 
 ImgExt = Literal['jpg', 'png', 'webp']
@@ -184,6 +186,11 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
 
+        # match the norm of the noise before computing loss. This will help the model maintain its
+        # current understanding of the brightness of images.
+
+        self.match_noise_norm = kwargs.get('match_noise_norm', False)
+
         # set to -1 to accumulate gradients for entire epoch
         # warning, only do this with a small dataset or you will run out of memory
         self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
@@ -406,6 +413,8 @@ class GenerateImageConfig:
             add_prompt_file: bool = False,  # add a prompt file with generated image
             adapter_image_path: str = None,  # path to adapter image
             adapter_conditioning_scale: float = 1.0,  # scale for adapter conditioning
+            latents: Optional[torch.Tensor] = None,  # input latent to start with
+            extra_kwargs: dict = None,  # extra data to save with prompt file
     ):
         self.width: int = width
         self.height: int = height
@@ -416,6 +425,7 @@ class GenerateImageConfig:
         self.prompt_2: str = prompt_2
         self.negative_prompt: str = negative_prompt
         self.negative_prompt_2: str = negative_prompt_2
+        self.latents: Optional[torch.Tensor] = latents
         self.output_path: str = output_path
 
         self.seed: int = seed
@@ -430,6 +440,7 @@ class GenerateImageConfig:
         self.gen_time: int = int(time.time() * 1000)
         self.adapter_image_path: str = adapter_image_path
         self.adapter_conditioning_scale: float = adapter_conditioning_scale
+        self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
 
         # prompt string will override any settings above
         self._process_prompt_string()
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 0326c931..52fe3a64 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -14,7 +14,7 @@ from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 
 from toolkit.basic import flush, value_map
-from toolkit.buckets import get_bucket_for_image_size
+from toolkit.buckets import get_bucket_for_image_size, get_resolution
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.prompt_utils import inject_trigger_into_prompt
 from torchvision import transforms
@@ -718,7 +718,17 @@ class PoiFileItemDTOMixin:
 
     def setup_poi_bucket(self: 'FileItemDTO'):
         # we are using poi, so we need to calculate the bucket based on the poi
-        resolution = self.dataset_config.resolution
+        # TODO this will allow poi to be smaller than resolution. Could affect training image size
+        poi_resolution = min(
+            self.dataset_config.resolution,
+            get_resolution(
+                self.poi_width * self.dataset_config.scale,
+                self.poi_height * self.dataset_config.scale
+            )
+        )
+
+        resolution = min(self.dataset_config.resolution, poi_resolution)
+
         bucket_tolerance = self.dataset_config.bucket_tolerance
         initial_width = int(self.width * self.dataset_config.scale)
         initial_height = int(self.height * self.dataset_config.scale)
@@ -727,12 +737,30 @@ class PoiFileItemDTOMixin:
         poi_width = int(self.poi_width * self.dataset_config.scale)
         poi_height = int(self.poi_height * self.dataset_config.scale)
 
-        # todo handle a poi that is smaller than resolution
         # determine new cropping
-        crop_left = random.randint(0, poi_x)
-        crop_right = random.randint(poi_x + poi_width, initial_width)
-        crop_top = random.randint(0, poi_y)
-        crop_bottom = random.randint(poi_y + poi_height, initial_height)
+
+        # crop left
+        if poi_x > 0:
+            crop_left = random.randint(0, poi_x)
+        else:
+            crop_left = 0
+
+        # crop right
+        cr_min = poi_x + poi_width
+        if cr_min < initial_width:
+            crop_right = random.randint(poi_x + poi_width, initial_width)
+        else:
+            crop_right = initial_width
+
+        if poi_y > 0:
+            crop_top = random.randint(0, poi_y)
+        else:
+            crop_top = 0
+
+        if poi_y + poi_height < initial_height:
+            crop_bottom = random.randint(poi_y + poi_height, initial_height)
+        else:
+            crop_bottom = initial_height
 
         new_width = crop_right - crop_left
         new_height = crop_bottom - crop_top
diff --git a/toolkit/inversion_utils.py b/toolkit/inversion_utils.py
new file mode 100644
index 00000000..51a61d83
--- /dev/null
+++ b/toolkit/inversion_utils.py
@@ -0,0 +1,410 @@
+# ref https://huggingface.co/spaces/editing-images/ledits/blob/main/inversion_utils.py
+
+import torch
+import os
+from tqdm import tqdm
+
+from toolkit import train_tools
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def mu_tilde(model, xt, x0, timestep):
+    "mu_tilde(x_t, x_0) DDPM paper eq. 7"
7" + prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps + alpha_prod_t_prev = model.scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod + alpha_t = model.scheduler.alphas[timestep] + beta_t = 1 - alpha_t + alpha_bar = model.scheduler.alphas_cumprod[timestep] + return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + ( + (alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt + + +def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50): + """ + Samples from P(x_1:T|x_0) + """ + # torch.manual_seed(43256465436) + alpha_bar = sd.noise_scheduler.alphas_cumprod + sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5 + alphas = sd.noise_scheduler.alphas + betas = 1 - alphas + # variance_noise_shape = ( + # num_inference_steps, + # sd.unet.in_channels, + # sd.unet.sample_size, + # sd.unet.sample_size) + variance_noise_shape = list(sample.shape) + variance_noise_shape[0] = num_inference_steps + + timesteps = sd.noise_scheduler.timesteps.to(sd.device) + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16) + for t in reversed(timesteps): + idx = t_to_idx[int(t)] + xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t] + xts = torch.cat([xts, sample], dim=0) + + return xts + + +def encode_text(model, prompts): + text_input = model.tokenizer( + prompts, + padding="max_length", + max_length=model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + with torch.no_grad(): + text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0] + return text_encoding + + +def forward_step(sd: StableDiffusion, model_output, timestep, sample): + next_timestep = min( + sd.noise_scheduler.config['num_train_timesteps'] - 2, + timestep + sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps + ) + + # 2. compute alphas, betas + alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep] + # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 5. 
+    next_sample = sd.noise_scheduler.add_noise(
+        pred_original_sample,
+        model_output,
+        torch.LongTensor([next_timestep]))
+    return next_sample
+
+
+def get_variance(sd: StableDiffusion, timestep):  # , prev_timestep):
+    prev_timestep = timestep - sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+    alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
+    alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
+        prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
+    beta_prod_t = 1 - alpha_prod_t
+    beta_prod_t_prev = 1 - alpha_prod_t_prev
+    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+    return variance
+
+
+def get_time_ids_from_latents(sd: StableDiffusion, latents: torch.Tensor):
+    VAE_SCALE_FACTOR = 2 ** (len(sd.vae.config['block_out_channels']) - 1)
+    if sd.is_xl:
+        bs, ch, h, w = list(latents.shape)
+
+        height = h * VAE_SCALE_FACTOR
+        width = w * VAE_SCALE_FACTOR
+
+        dtype = latents.dtype
+        # just do it without any cropping nonsense
+        target_size = (height, width)
+        original_size = (height, width)
+        crops_coords_top_left = (0, 0)
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        add_time_ids = torch.tensor([add_time_ids])
+        add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
+
+        batch_time_ids = torch.cat(
+            [add_time_ids for _ in range(bs)]
+        )
+        return batch_time_ids
+    else:
+        return None
+
+
+def inversion_forward_process(
+        sd: StableDiffusion,
+        sample: torch.Tensor,
+        conditional_embeddings: PromptEmbeds,
+        unconditional_embeddings: PromptEmbeds,
+        etas=None,
+        prog_bar=False,
+        cfg_scale=3.5,
+        num_inference_steps=50, eps=None
+):
+    current_num_timesteps = len(sd.noise_scheduler.timesteps)
+    sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device)
+
+    timesteps = sd.noise_scheduler.timesteps.to(sd.device)
+    # variance_noise_shape = (
+    #     num_inference_steps,
+    #     sd.unet.in_channels,
+    #     sd.unet.sample_size,
+    #     sd.unet.sample_size
+    # )
+    variance_noise_shape = list(sample.shape)
+    variance_noise_shape[0] = num_inference_steps
+    if etas is None or (type(etas) in [int, float] and etas == 0):
+        eta_is_zero = True
+        zs = None
+    else:
+        eta_is_zero = False
+        if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps
+        xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps)
+        alpha_bar = sd.noise_scheduler.alphas_cumprod
+        zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16)
+
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+    noisy_sample = sample
+    op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
+
+    for timestep in op:
+        idx = t_to_idx[int(timestep)]
+        # 1. predict noise residual
+        if not eta_is_zero:
+            noisy_sample = xts[idx][None]
+
+        added_cond_kwargs = {}
+
+        with torch.no_grad():
+            text_embeddings = train_tools.concat_prompt_embeddings(
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+                1,  # batch size
+            )
+            if sd.is_xl:
+                add_time_ids = get_time_ids_from_latents(sd, noisy_sample)
+                # add extra for cfg
+                add_time_ids = torch.cat(
+                    [add_time_ids] * 2, dim=0
+                )
+
+                added_cond_kwargs = {
+                    "text_embeds": text_embeddings.pooled_embeds,
+                    "time_ids": add_time_ids,
+                }
+
+            # double up for cfg
+            latent_model_input = torch.cat(
+                [noisy_sample] * 2, dim=0
+            )
+
+            noise_pred = sd.unet(
+                latent_model_input,
+                timestep,
+                encoder_hidden_states=text_embeddings.text_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
+
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+
+            # out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding)
+            # cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings)
+
+        noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
+
+        if eta_is_zero:
+            # 2. compute more noisy image and set x_t -> x_t+1
+            noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample)
+            xts = None
+
+        else:
+            xtm1 = xts[idx + 1][None]
+            # pred of x0
+            pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[
+                timestep] ** 0.5
+
+            # direction to xt
+            prev_timestep = timestep - sd.noise_scheduler.config[
+                'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+            alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
+                prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
+
+            variance = get_variance(sd, timestep)
+            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
+
+            mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+            zs[idx] = z
+
+            # correction to avoid error accumulation
+            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+            xts[idx + 1] = xtm1
+
+    if zs is not None:
+        zs[-1] = torch.zeros_like(zs[-1])
+
+    # restore timesteps
+    sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device)
+
+    return noisy_sample, zs, xts
+
+
+#
+# def inversion_forward_process(
+#         model,
+#         sample,
+#         etas=None,
+#         prog_bar=False,
+#         prompt="",
+#         cfg_scale=3.5,
+#         num_inference_steps=50, eps=None
+# ):
+#     if not prompt == "":
+#         text_embeddings = encode_text(model, prompt)
+#     uncond_embedding = encode_text(model, "")
+#     timesteps = model.scheduler.timesteps.to(model.device)
+#     variance_noise_shape = (
+#         num_inference_steps,
+#         model.unet.in_channels,
+#         model.unet.sample_size,
+#         model.unet.sample_size)
+#     if etas is None or (type(etas) in [int, float] and etas == 0):
+#         eta_is_zero = True
+#         zs = None
+#     else:
+#         eta_is_zero = False
+#         if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+#         xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps)
+#         alpha_bar = model.scheduler.alphas_cumprod
+#         zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype=torch.float16)
+#
+#     t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+#     noisy_sample = sample
+#     op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
+#
+#     for t in op:
+#         idx = t_to_idx[int(t)]
+#         # 1. predict noise residual
+#         if not eta_is_zero:
+#             noisy_sample = xts[idx][None]
+#
+#         with torch.no_grad():
+#             out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=uncond_embedding)
+#             if not prompt == "":
+#                 cond_out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=text_embeddings)
+#
+#         if not prompt == "":
+#             ## classifier free guidance
+#             noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
+#         else:
+#             noise_pred = out.sample
+#
+#         if eta_is_zero:
+#             # 2. compute more noisy image and set x_t -> x_t+1
+#             noisy_sample = forward_step(model, noise_pred, t, noisy_sample)
+#
+#         else:
+#             xtm1 = xts[idx + 1][None]
+#             # pred of x0
+#             pred_original_sample = (noisy_sample - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
+#
+#             # direction to xt
+#             prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+#             alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+#                 prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+#
+#             variance = get_variance(model, t)
+#             pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
+#
+#             mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+#
+#             z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+#             zs[idx] = z
+#
+#             # correction to avoid error accumulation
+#             xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+#             xts[idx + 1] = xtm1
+#
+#     if not zs is None:
+#         zs[-1] = torch.zeros_like(zs[-1])
+#
+#     return noisy_sample, zs, xts
+
+
+def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
+    # 1. get previous step value (=t-1)
+    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+    # 2. compute alphas, betas
+    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+    alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+        prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+    beta_prod_t = 1 - alpha_prod_t
+    # 3. compute predicted original sample from predicted noise also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+    # 5. compute variance: "sigma_t(η)" -> see formula (16)
+    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+    # variance = self.scheduler._get_variance(timestep, prev_timestep)
+    variance = get_variance(model, timestep)  # , prev_timestep)
+    std_dev_t = eta * variance ** (0.5)
+    # Take care of asymmetric reverse process (asyrp)
+    model_output_direction = model_output
+    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
+    pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
+    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+    # 8. Add noise if eta > 0
+    if eta > 0:
+        if variance_noise is None:
+            variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16)
+        sigma_z = eta * variance ** (0.5) * variance_noise
+        prev_sample = prev_sample + sigma_z
+
+    return prev_sample
+
+
+def inversion_reverse_process(
+        model,
+        xT,
+        etas=0,
+        prompts="",
+        cfg_scales=None,
+        prog_bar=False,
+        zs=None,
+        controller=None,
+        asyrp=False):
+    batch_size = len(prompts)
+
+    cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device, dtype=torch.float16)
+
+    text_embeddings = encode_text(model, prompts)
+    uncond_embedding = encode_text(model, [""] * batch_size)
+
+    if etas is None: etas = 0
+    if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+    assert len(etas) == model.scheduler.num_inference_steps
+    timesteps = model.scheduler.timesteps.to(model.device)
+
+    xt = xT.expand(batch_size, -1, -1, -1)
+    op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
+
+    for t in op:
+        idx = t_to_idx[int(t)]
+        ## Unconditional embedding
+        with torch.no_grad():
+            uncond_out = model.unet.forward(xt, timestep=t,
+                                            encoder_hidden_states=uncond_embedding)
+
+        ## Conditional embedding
+        if prompts:
+            with torch.no_grad():
+                cond_out = model.unet.forward(xt, timestep=t,
+                                              encoder_hidden_states=text_embeddings)
+
+        z = zs[idx] if zs is not None else None
+        z = z.expand(batch_size, -1, -1, -1)
+        if prompts:
+            ## classifier free guidance
+            noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+        else:
+            noise_pred = uncond_out.sample
+        # 2. compute less noisy image and set x_t -> x_t-1
+        xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
+        if controller is not None:
+            xt = controller.step_callback(xt)
+    return xt, zs
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index dbacfa79..5b8b93d2 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -424,6 +424,10 @@ class StableDiffusion:
                 if sampler.startswith("sample_"):
                     extra['use_karras_sigmas'] = True
 
+                extra = {
+                    **extra,
+                    **gen_config.extra_kwargs,
+                }
 
                 img = pipeline(
                     # prompt=gen_config.prompt,
@@ -439,6 +443,7 @@ class StableDiffusion:
                     num_inference_steps=gen_config.num_inference_steps,
                     guidance_scale=gen_config.guidance_scale,
                     guidance_rescale=grs,
+                    latents=gen_config.latents,
                     **extra
                 ).images[0]
             else:
@@ -451,6 +456,7 @@ class StableDiffusion:
                     width=gen_config.width,
                     num_inference_steps=gen_config.num_inference_steps,
                     guidance_scale=gen_config.guidance_scale,
+                    latents=gen_config.latents,
                    **extra
                 ).images[0]
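
Note for reviewers: a minimal standalone sketch of the noise-norm matching that the new match_noise_norm flag turns on in SDTrainer. The helper name below is illustrative and not part of this patch; the torch calls are the same ones the diff adds.

    import torch

    def rescale_noise_pred(noise_pred: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        # Rescale the predicted noise so its per-sample L2 norm matches the target noise,
        # the operation SDTrainer applies when train_config.match_noise_norm is set.
        # The intent is to help the model keep its current understanding of image brightness.
        noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
        noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
        return noise_pred * (noise_norm / noise_pred_norm)

Since TrainConfig reads the flag with kwargs.get('match_noise_norm', False), it is off by default and would presumably be enabled by setting match_noise_norm: true in the train section of a job config.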