Various bug fixes and improvements

This commit is contained in:
Jaret Burkett
2023-08-12 05:59:50 -06:00
parent 67dfd9ced0
commit 379992d89e
5 changed files with 180 additions and 93 deletions

View File

@@ -2,11 +2,12 @@ import copy
import random
from collections import OrderedDict
import os
from contextlib import nullcontext
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader
from toolkit.data_loader import PairedImageDataset
from toolkit.prompt_utils import concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype
import gc
from toolkit import train_tools
@@ -80,34 +81,16 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
imgs, prompts = batch
dtype = get_torch_dtype(self.train_config.dtype)
imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
# split batched images in half so left is negative and right is positive
negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
positive_latents = self.sd.encode_images(positive_images)
negative_latents = self.sd.encode_images(negative_images)
height = positive_images.shape[2]
width = positive_images.shape[3]
batch_size = positive_images.shape[0]
# encode the images
positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
positive_latents = positive_latents * 0.18215
negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
negative_latents = negative_latents * 0.18215
embedding_list = []
negative_embedding_list = []
# embed the prompts
for prompt in prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
# just empty for now
# todo cache this?
negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
negative_embedding_list.append(negative_embed)
conditional_embeds = concat_prompt_embeds(embedding_list)
unconditional_embeds = concat_prompt_embeds(negative_embedding_list)
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
@@ -115,26 +98,12 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
def get_noise_pred(neg, pos, gs, cts, dn):
    """Predict noise for the latents `dn` at timestep(s) `cts`.

    `neg` and `pos` are the negative and positive prompt embeddings; they
    are combined with `train_tools.concat_prompt_embeddings` using the
    configured training batch size, then handed to `self.sd.predict_noise`
    together with the guidance scale `gs`.
    """
    # negative prompt first, positive prompt second
    paired_embeddings = train_tools.concat_prompt_embeddings(
        neg,
        pos,
        self.train_config.batch_size,
    )
    return self.sd.predict_noise(
        latents=dn,
        text_embeddings=paired_embeddings,
        timestep=cts,
        guidance_scale=gs,
    )
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
timesteps = timesteps.long()
# get noise
@@ -147,6 +116,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
if do_mirror_loss:
# mirror the noise
# torch shape is [batch, channels, height, width]
noise_negative = torch.flip(noise_positive.clone(), dims=[3])
else:
noise_negative = noise_positive.clone()
@@ -159,8 +129,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
noise = torch.cat([noise_positive, noise_negative], dim=0)
timesteps = torch.cat([timesteps, timesteps], dim=0)
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
network_multiplier = [1.0, -1.0]
flush()
@@ -170,22 +138,31 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss_mirror_float = None
self.optimizer.zero_grad()
noisy_latents.requires_grad = False
# if training text encoder enable grads, else do context of no grad
with torch.set_grad_enabled(self.train_config.train_text_encoder):
# text encoding
embedding_list = []
# embed the prompts
for prompt in prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
with self.network:
assert self.network.is_active
loss_list = []
# do positive first
self.network.multiplier = network_multiplier
noise_pred = get_noise_pred(
unconditional_embeds,
conditional_embeds,
1,
timesteps,
noisy_latents
noise_pred = self.sd.predict_noise(
latents=noisy_latents,
conditional_embeddings=conditional_embeds,
timestep=timesteps,
)
if self.sd.is_v2: # check is vpred, don't want to track it down right now
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
@@ -199,7 +176,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss = loss.mean()
loss_slide_float = loss.item()
if do_mirror_loss:
noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
# mirror the negative
@@ -221,7 +197,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
optimizer.step()
lr_scheduler.step()
# reset network
self.network.multiplier = 1.0

View File

@@ -9,17 +9,21 @@ config:
# for tensorboard logging
log_dir: "/home/jaret/Dev/.tensorboard"
network:
type: "lierla" # lierla is traditional LoRA that works everywhere, only linear layers
rank: 16
alpha: 8
type: "lora"
linear: 64
linear_alpha: 32
conv: 32
conv_alpha: 16
train:
noise_scheduler: "ddpm" # one of "ddpm", "lms", "euler_a"
steps: 1000
lr: 5e-5
steps: 5000
lr: 1e-4
train_unet: true
gradient_checkpointing: true
train_text_encoder: false
optimizer: "lion8bit"
train_text_encoder: true
optimizer: "adamw"
optimizer_params:
weight_decay: 1e-2
lr_scheduler: "constant"
max_denoising_steps: 1000
batch_size: 1
@@ -36,11 +40,11 @@ config:
is_v_pred: false # for v-prediction models (most v2 models)
save:
dtype: float16 # precision to save
save_every: 100 # save every this many steps
save_every: 1000 # save every this many steps
max_step_saves_to_keep: 2 # only affects step counts
sample:
sampler: "ddpm" # must match train.noise_scheduler
sample_every: 20 # sample every this many steps
sample_every: 100 # sample every this many steps
width: 512
height: 512
prompts:
@@ -81,6 +85,8 @@ config:
- 512
slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
target_class: "photo of a person"
# additional_losses:
# - "mirror"
meta: