Various bug fixes and improvements

Jaret Burkett
2023-08-12 05:59:50 -06:00
parent 67dfd9ced0
commit 379992d89e
5 changed files with 180 additions and 93 deletions

View File

@@ -2,11 +2,12 @@ import copy
 import random
 from collections import OrderedDict
 import os
+from contextlib import nullcontext
 from typing import Optional, Union, List
 from torch.utils.data import ConcatDataset, DataLoader
 from toolkit.data_loader import PairedImageDataset
 from toolkit.prompt_utils import concat_prompt_embeds
-from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
 from toolkit.train_tools import get_torch_dtype
 import gc
 from toolkit import train_tools
@@ -80,34 +81,16 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         imgs, prompts = batch
         dtype = get_torch_dtype(self.train_config.dtype)
         imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
         # split batched images in half so left is negative and right is positive
         negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
+        positive_latents = self.sd.encode_images(positive_images)
+        negative_latents = self.sd.encode_images(negative_images)
         height = positive_images.shape[2]
         width = positive_images.shape[3]
         batch_size = positive_images.shape[0]
-        # encode the images
-        positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
-        positive_latents = positive_latents * 0.18215
-        negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
-        negative_latents = negative_latents * 0.18215
-
-        embedding_list = []
-        negative_embedding_list = []
-        # embed the prompts
-        for prompt in prompts:
-            embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
-            embedding_list.append(embedding)
-            # just empty for now
-            # todo cache this?
-            negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
-            negative_embedding_list.append(negative_embed)
-        conditional_embeds = concat_prompt_embeds(embedding_list)
-        unconditional_embeds = concat_prompt_embeds(negative_embedding_list)

         if self.train_config.gradient_checkpointing:
             # may get disabled elsewhere
             self.sd.unet.enable_gradient_checkpointing()
@@ -115,26 +98,12 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         noise_scheduler = self.sd.noise_scheduler
         optimizer = self.optimizer
         lr_scheduler = self.lr_scheduler
-        loss_function = torch.nn.MSELoss()
-
-        def get_noise_pred(neg, pos, gs, cts, dn):
-            return self.sd.predict_noise(
-                latents=dn,
-                text_embeddings=train_tools.concat_prompt_embeddings(
-                    neg,  # negative prompt
-                    pos,  # positive prompt
-                    self.train_config.batch_size,
-                ),
-                timestep=cts,
-                guidance_scale=gs,
-            )
-
-        self.sd.noise_scheduler.set_timesteps(
-            self.train_config.max_denoising_steps, device=self.device_torch
-        )
-
-        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
-        timesteps = timesteps.long()
+        with torch.no_grad():
+            self.sd.noise_scheduler.set_timesteps(
+                self.train_config.max_denoising_steps, device=self.device_torch
+            )
+
+            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
+            timesteps = timesteps.long()

         # get noise
@@ -147,6 +116,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         if do_mirror_loss:
             # mirror the noise
+            # torch shape is [batch, channels, height, width]
             noise_negative = torch.flip(noise_positive.clone(), dims=[3])
         else:
             noise_negative = noise_positive.clone()
@@ -159,8 +129,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
         noise = torch.cat([noise_positive, noise_negative], dim=0)
         timesteps = torch.cat([timesteps, timesteps], dim=0)
-        conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
-        unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])

         network_multiplier = [1.0, -1.0]

         flush()
@@ -170,22 +138,31 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         loss_mirror_float = None

         self.optimizer.zero_grad()
+        noisy_latents.requires_grad = False
+
+        # if training text encoder enable grads, else do context of no grad
+        with torch.set_grad_enabled(self.train_config.train_text_encoder):
+            # text encoding
+            embedding_list = []
+            # embed the prompts
+            for prompt in prompts:
+                embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+                embedding_list.append(embedding)
+            conditional_embeds = concat_prompt_embeds(embedding_list)
+            conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])

         with self.network:
             assert self.network.is_active
-            loss_list = []
-            # do positive first
             self.network.multiplier = network_multiplier
-            noise_pred = get_noise_pred(
-                unconditional_embeds,
-                conditional_embeds,
-                1,
-                timesteps,
-                noisy_latents
-            )
+            noise_pred = self.sd.predict_noise(
+                latents=noisy_latents,
+                conditional_embeddings=conditional_embeds,
+                timestep=timesteps,
+            )

-            if self.sd.is_v2:  # check is vpred, don't want to track it down right now
+            if self.sd.prediction_type == 'v_prediction':
                 # v-parameterization training
                 target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
             else:
@@ -199,7 +176,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             loss = loss.mean()
             loss_slide_float = loss.item()
-
             if do_mirror_loss:
                 noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
                 # mirror the negative
@@ -221,7 +197,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             optimizer.step()
             lr_scheduler.step()
-
             # reset network
             self.network.multiplier = 1.0
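
Editorial note (not part of the commit): conditional_embeds is concatenated with itself above so the embedding batch lines up with the [positive | negative] halves stacked into noisy_latents; the reworked predict_noise detects matched batch sizes and skips classifier-free guidance, which is why the old get_noise_pred closure and the empty-prompt unconditional embeddings could be deleted. A minimal standalone sketch with illustrative SD1.x shapes:

    import torch

    batch_size = 1
    # positive and negative latent halves stacked along the batch dim
    noisy_latents = torch.randn(2 * batch_size, 4, 64, 64)
    # one prompt embedding per image pair; 77 tokens x 768 dims is SD1.x
    embeds = torch.randn(batch_size, 77, 768)
    # double the embeddings so they line up with both latent halves
    embeds = torch.cat([embeds, embeds], dim=0)
    # matched batch sizes are how predict_noise decides CFG is not needed
    assert noisy_latents.shape[0] == embeds.shape[0]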

View File

@@ -9,17 +9,21 @@ config:
   # for tensorboard logging
   log_dir: "/home/jaret/Dev/.tensorboard"
   network:
-    type: "lierla"  # lierla is traditional LoRA that works everywhere, only linear layers
-    rank: 16
-    alpha: 8
+    type: "lora"
+    linear: 64
+    linear_alpha: 32
+    conv: 32
+    conv_alpha: 16
   train:
     noise_scheduler: "ddpm"  # or "ddpm", "lms", "euler_a"
-    steps: 1000
-    lr: 5e-5
+    steps: 5000
+    lr: 1e-4
     train_unet: true
     gradient_checkpointing: true
-    train_text_encoder: false
-    optimizer: "lion8bit"
+    train_text_encoder: true
+    optimizer: "adamw"
+    optimizer_params:
+      weight_decay: 1e-2
     lr_scheduler: "constant"
     max_denoising_steps: 1000
     batch_size: 1
@@ -36,11 +40,11 @@ config:
     is_v_pred: false  # for v-prediction models (most v2 models)
   save:
     dtype: float16  # precision to save
-    save_every: 100  # save every this many steps
+    save_every: 1000  # save every this many steps
     max_step_saves_to_keep: 2  # only affects step counts
   sample:
     sampler: "ddpm"  # must match train.noise_scheduler
-    sample_every: 20  # sample every this many steps
+    sample_every: 100  # sample every this many steps
     width: 512
     height: 512
     prompts:
@@ -81,6 +85,8 @@ config:
       - 512
     slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
     target_class: "photo of a person"
+    # additional_losses:
+    #   - "mirror"

 meta:
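
Editorial note (not part of the commit): the network block switches from the linear-only "lierla" LoRA, which took a single rank/alpha, to a "lora" network with separate rank/alpha pairs for linear and conv layers. A hedged sketch of reading such a block, assuming a config.yaml laid out like the hunks above with a top-level config key:

    import yaml

    with open("config.yaml") as f:
        cfg = yaml.safe_load(f)["config"]

    net = cfg["network"]
    # separate linear and conv ranks/alphas replace the old single rank/alpha
    linear_rank, linear_alpha = net["linear"], net["linear_alpha"]
    conv_rank, conv_alpha = net.get("conv"), net.get("conv_alpha")
    # the effective LoRA scale is alpha / rank, here 32 / 64 = 0.5
    scale = linear_alpha / linear_rank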

View File

@@ -97,21 +97,25 @@ class LoRAModule(torch.nn.Module):
             if len(self.multiplier) == 0:
                 # single item, just return it
                 return self.multiplier[0]
+            elif len(self.multiplier) == batch_size:
+                # not doing CFG
+                multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
             else:
                 # we have a list of multipliers, so we need to get the multiplier for this batch
                 multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
                 # should be 1 for if total batch size was 1
                 num_interleaves = (batch_size // 2) // len(self.multiplier)
                 multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)

             # match lora_up rank
             if len(lora_up.size()) == 2:
                 multiplier_tensor = multiplier_tensor.view(-1, 1)
             elif len(lora_up.size()) == 3:
                 multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
             elif len(lora_up.size()) == 4:
                 multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)

             return multiplier_tensor
         else:
             return self.multiplier
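
Editorial note (not part of the commit): the new elif handles the non-CFG case, where one multiplier per sample is already present and no doubling is needed. One aside: the untouched len(self.multiplier) == 0 guard, paired with the "single item, just return it" comment, reads as if it were meant to be == 1, since indexing an empty list would raise; that line is context here, not something this commit changed. A standalone sketch of the broadcast logic, with assumed shapes:

    import torch

    multiplier = [1.0, -1.0]   # positive / negative slider direction
    batch_size = 4             # e.g. CFG doubling of two latent pairs
    lora_up_out = torch.randn(batch_size, 77, 768)

    if len(multiplier) == batch_size:
        mt = torch.tensor(multiplier)        # non-CFG: one multiplier per sample
    else:
        mt = torch.tensor(multiplier * 2)    # CFG: repeat for uncond + cond halves
        mt = mt.repeat_interleave((batch_size // 2) // len(multiplier))
    mt = mt.view(-1, 1, 1)     # reshape to broadcast over a 3D lora_up output
    scaled = lora_up_out * mt.to(lora_up_out.dtype)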

View File

@@ -7,9 +7,11 @@ import os
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file
 from tqdm import tqdm
+from torchvision.transforms import Resize

 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict
+from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT
@@ -180,6 +182,7 @@ class StableDiffusion:
                 device=self.device_torch,
                 load_safety_checker=False,
                 requires_safety_checker=False,
+                safety_checker=False
             ).to(self.device_torch)
         else:
             pipe = pipln.from_single_file(
@@ -189,7 +192,9 @@ class StableDiffusion:
                 device=self.device_torch,
                 load_safety_checker=False,
                 requires_safety_checker=False,
+                safety_checker=False
             ).to(self.device_torch)
+        pipe.register_to_config(requires_safety_checker=False)

         text_encoder = pipe.text_encoder
         text_encoder.to(self.device_torch, dtype=dtype)
@@ -379,28 +384,60 @@ class StableDiffusion:
                 dynamic_crops=False,  # look into this
                 dtype=dtype,
             ).to(self.device_torch, dtype=dtype)
-            return train_util.concat_embeddings(
-                prompt_ids, prompt_ids, bs
-            )
+            return prompt_ids
         else:
             return None

     def predict_noise(
             self,
-            latents: torch.FloatTensor,
-            text_embeddings: PromptEmbeds,
-            timestep: int,
+            latents: torch.Tensor,
+            text_embeddings: Union[PromptEmbeds, None] = None,
+            timestep: Union[int, torch.Tensor] = 1,
             guidance_scale=7.5,
-            guidance_rescale=0,  # 0.7
+            guidance_rescale=0,  # 0.7 sdxl
             add_time_ids=None,
+            conditional_embeddings: Union[PromptEmbeds, None] = None,
+            unconditional_embeddings: Union[PromptEmbeds, None] = None,
             **kwargs,
     ):
+        # get the embeddings
+        if text_embeddings is None and conditional_embeddings is None:
+            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+        if text_embeddings is None and unconditional_embeddings is not None:
+            text_embeddings = train_tools.concat_prompt_embeddings(
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+                latents.shape[0],  # batch size
+            )
+        elif text_embeddings is None and conditional_embeddings is not None:
+            # not doing cfg
+            text_embeddings = conditional_embeddings
+
+        # CFG is comparing neg and positive, if we have concatenated embeddings
+        # then we are doing it, otherwise we are not and takes half the time.
+        do_classifier_free_guidance = True
+
+        # check if batch size of embeddings matches batch size of latents
+        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+            do_classifier_free_guidance = False
+        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+
         if self.is_xl:
             if add_time_ids is None:
                 add_time_ids = self.get_time_ids_from_latents(latents)
-            latent_model_input = torch.cat([latents] * 2)
+            if do_classifier_free_guidance:
+                # todo check this with larget batches
+                add_time_ids = train_util.concat_embeddings(
+                    add_time_ids, add_time_ids, 1
+                )
+            else:
+                # concat to fit batch size
+                add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
+
+            if do_classifier_free_guidance:
+                latent_model_input = torch.cat([latents] * 2)
+            else:
+                latent_model_input = latents

             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
@@ -417,20 +454,24 @@ class StableDiffusion:
                 added_cond_kwargs=added_cond_kwargs,
             ).sample

-            # perform guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
+            if do_classifier_free_guidance:
+                # perform guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )

             # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
             if guidance_rescale > 0.0:
                 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                 noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

         else:
-            # if we are doing classifier free guidance, need to double up
-            latent_model_input = torch.cat([latents] * 2)
+            if do_classifier_free_guidance:
+                # if we are doing classifier free guidance, need to double up
+                latent_model_input = torch.cat([latents] * 2)
+            else:
+                latent_model_input = latents

             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
@@ -441,10 +482,12 @@ class StableDiffusion:
                 encoder_hidden_states=text_embeddings.text_embeds,
             ).sample

-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
+            if do_classifier_free_guidance:
+                # perform guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )

         return noise_pred
@@ -495,14 +538,68 @@ class StableDiffusion:
             )
         )

+    def encode_images(
+            self,
+            image_list: List[torch.Tensor],
+            device=None,
+            dtype=None
+    ):
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.torch_dtype
+
+        latent_list = []
+        # Move to vae to device if on cpu
+        if self.vae.device == 'cpu':
+            self.vae.to(self.device)
+        # move to device and dtype
+        image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
+
+        # resize images if not divisible by 8
+        for i in range(len(image_list)):
+            image = image_list[i]
+            if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0:
+                image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
+
+        images = torch.stack(image_list)
+        latents = self.vae.encode(images).latent_dist.sample()
+        latents = latents * 0.18215
+        latents = latents.to(device, dtype=dtype)
+
+        return latents
+
+    def encode_image_prompt_pairs(
+            self,
+            prompt_list: List[str],
+            image_list: List[torch.Tensor],
+            device=None,
+            dtype=None
+    ):
+        # todo check image types and expand and rescale as needed
+        # device and dtype are for outputs
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.torch_dtype
+
+        embedding_list = []
+        latent_list = []
+        # embed the prompts
+        for prompt in prompt_list:
+            embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+            embedding_list.append(embedding)
+
+        return embedding_list, latent_list
+
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
         state_dict = {}

         def update_sd(prefix, sd):
             for k, v in sd.items():
                 key = prefix + k
-                v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
-                state_dict[key] = v
+                v = v.detach().clone()
+                state_dict[key] = v.to("cpu", dtype=get_torch_dtype(save_dtype))

         # todo see what logit scale is
         if self.is_xl:
@@ -536,4 +633,6 @@ class StableDiffusion:
         # prepare metadata
         meta = get_meta_for_safetensors(meta)
+        # make sure parent folder exists
+        os.makedirs(os.path.dirname(output_file), exist_ok=True)
         save_file(state_dict, output_file, metadata=meta)
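
Editorial note on the reworked predict_noise (a usage sketch, not code from the commit): callers may still pass pre-concatenated text_embeddings, or pass conditional/unconditional embeddings and let the method infer classifier-free guidance from batch sizes alone. Assuming an initialized StableDiffusion instance sd plus latents and t from a denoising loop:

    cond = sd.encode_prompt("photo of a person")
    uncond = sd.encode_prompt("")

    # CFG path: uncond + cond are concatenated internally to 2x the latent batch
    noise_cfg = sd.predict_noise(
        latents=latents,
        timestep=t,
        conditional_embeddings=cond,
        unconditional_embeddings=uncond,
        guidance_scale=7.5,
    )

    # single-pass path: a matching batch size disables CFG and halves the UNet work
    noise_fast = sd.predict_noise(
        latents=latents,
        timestep=t,
        conditional_embeddings=cond,
    )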

View File

@@ -34,13 +34,16 @@ SCHEDLER_SCHEDULE = "scaled_linear"

 def get_torch_dtype(dtype_str):
+    # if it is a torch dtype, return it
+    if isinstance(dtype_str, torch.dtype):
+        return dtype_str
     if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32":
         return torch.float
     if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16":
         return torch.float16
     if dtype_str == "bf16" or dtype_str == "bfloat16":
         return torch.bfloat16
-    return None
+    return dtype_str


 def replace_filewords_prompt(prompt, args: argparse.Namespace):
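
Editorial note (not part of the commit): get_torch_dtype is now idempotent, so call sites can hand it either a string or an existing torch.dtype, and unrecognized values pass through instead of silently becoming None. A condensed standalone sketch of the new behavior:

    import torch

    def get_torch_dtype(dtype_str):
        if isinstance(dtype_str, torch.dtype):
            return dtype_str              # already a dtype: pass through
        if dtype_str in ("float", "fp32", "single", "float32"):
            return torch.float
        if dtype_str in ("fp16", "half", "float16"):
            return torch.float16
        if dtype_str in ("bf16", "bfloat16"):
            return torch.bfloat16
        return dtype_str                  # unknown values pass through, not None

    assert get_torch_dtype(torch.float16) is torch.float16            # idempotent
    assert get_torch_dtype(get_torch_dtype("bf16")) is torch.bfloat16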