Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so I can manipulate them

This commit is contained in:
Jaret Burkett
2023-09-16 17:41:07 -06:00
parent 27f343fc08
commit c698837241
11 changed files with 214 additions and 78 deletions
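To illustrate the direction of the prompt-embedding change described above, here is a minimal sketch of the new flow for SDXL. It assumes the PromptEmbeds container, the encode_prompt wrapper, and the post_process_embeddings hook added later in this diff; the sd, pipeline, and gen_config objects are stand-ins for the trainer's own instances, not code from the commit.

def generate_with_custom_embeds(sd, pipeline, gen_config):
    # Encode prompts outside the diffusers pipeline so the embeddings stay accessible.
    cond = sd.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
    uncond = sd.encode_prompt(
        gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
    )

    # Hook for arbitrary manipulation (scaling, blending, token surgery, ...).
    gen_config.post_process_embeddings(cond, uncond)

    # Hand tensors to the pipeline instead of raw prompt strings (SDXL kwargs).
    return pipeline(
        prompt_embeds=cond.text_embeds,
        pooled_prompt_embeds=cond.pooled_embeds,
        negative_prompt_embeds=uncond.text_embeds,
        negative_pooled_prompt_embeds=uncond.pooled_embeds,
        height=gen_config.height,
        width=gen_config.width,
        num_inference_steps=gen_config.num_inference_steps,
    ).images[0]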

View File

@@ -3,10 +3,12 @@ import glob
import os
import time
from collections import OrderedDict
from typing import List, Optional
from PIL import Image
from PIL.ImageOps import exif_transpose
# from basicsr.archs.rrdbnet_arch import RRDBNet
from toolkit.basic import flush
from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
@@ -67,9 +69,10 @@ class TrainESRGANProcess(BaseTrainProcess):
self.augmentations = self.get_conf('augmentations', {})
self.torch_dtype = get_torch_dtype(self.dtype)
if self.torch_dtype == torch.bfloat16:
self.esrgan_dtype = torch.float16
self.esrgan_dtype = torch.float32
else:
self.esrgan_dtype = torch.float32
self.vgg_19 = None
self.style_weight_scalers = []
self.content_weight_scalers = []
@@ -232,6 +235,7 @@ class TrainESRGANProcess(BaseTrainProcess):
pattern_size=self.zoom,
dtype=self.torch_dtype
).to(self.device, dtype=self.torch_dtype)
self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype)
loss = torch.mean(self._pattern_loss(pred, target))
return loss
@@ -269,13 +273,52 @@ class TrainESRGANProcess(BaseTrainProcess):
if self.use_critic:
self.critic.save(step)
def sample(self, step=None):
def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder):
os.makedirs(sample_folder, exist_ok=True)
batch_sample_folder = os.path.join(self.save_root, 'samples_batch')
batch_targets = None
batch_inputs = None
if batch is not None and not os.path.exists(batch_sample_folder):
os.makedirs(batch_sample_folder, exist_ok=True)
self.model.eval()
def process_and_save(img, target_img, save_path):
output = self.model(img.to(self.device, dtype=self.esrgan_dtype))
# output = (output / 2 + 0.5).clamp(0, 1)
output = output.clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
# convert to pillow image
output = Image.fromarray((output * 255).astype(np.uint8))
if isinstance(target_img, torch.Tensor):
# convert to pil
target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
target_img = Image.fromarray((target_img * 255).astype(np.uint8))
# upscale to size * self.upscale_sample while maintaining pixels
output = output.resize(
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
width, height = output.size
# stack input image and decoded image
target_image = target_img.resize((width, height))
output = output.resize((width, height))
output_img = Image.new('RGB', (width * 2, height))
output_img.paste(target_image, (0, 0))
output_img.paste(output, (width, 0))
output_img.save(save_path)
with torch.no_grad():
for i, img_url in enumerate(self.sample_sources):
img = exif_transpose(Image.open(img_url))
@@ -295,30 +338,6 @@ class TrainESRGANProcess(BaseTrainProcess):
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
img = img
output = self.model(img)
# output = (output / 2 + 0.5).clamp(0, 1)
output = output.clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
# convert to pillow image
output = Image.fromarray((output * 255).astype(np.uint8))
# upscale to size * self.upscale_sample while maintaining pixels
output = output.resize(
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
width, height = output.size
# stack input image and decoded image
target_image = target_image.resize((width, height))
output = output.resize((width, height))
output_img = Image.new('RGB', (width * 2, height))
output_img.paste(target_image, (0, 0))
output_img.paste(output, (width, 0))
step_num = ''
if step is not None:
@@ -328,7 +347,23 @@ class TrainESRGANProcess(BaseTrainProcess):
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_img.save(os.path.join(sample_folder, filename))
process_and_save(img, target_image, os.path.join(sample_folder, filename))
if batch is not None:
batch_targets = batch[0].detach()
batch_inputs = batch[1].detach()
batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0)
batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0)
for i in range(len(batch_inputs)):
if step is not None:
# zero-pad 9 digits
step_num = f"_{str(step).zfill(9)}"
seconds_since_epoch = int(time.time())
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
self.model.train()
@@ -445,35 +480,60 @@ class TrainESRGANProcess(BaseTrainProcess):
print("Generating baseline samples")
self.sample(step=0)
# range start at self.epoch_num go to self.epochs
critic_losses = []
for epoch in range(self.epoch_num, self.epochs, 1):
if self.step_num >= self.max_steps:
break
flush()
for targets, inputs in self.data_loader:
if self.step_num >= self.max_steps:
break
with torch.no_grad():
targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
is_critic_only_step = False
if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform():
is_critic_only_step = True
pred = self.model(inputs)
targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
optimizer.zero_grad()
# dont do grads here for critic step
do_grad = not is_critic_only_step
with torch.set_grad_enabled(do_grad):
pred = self.model(inputs)
# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, targets], dim=0)
# stacked = (stacked / 2 + 0.5).clamp(0, 1)
stacked = stacked.clamp(0, 1)
self.vgg_19(stacked)
pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
if torch.isnan(pred).any():
raise ValueError('pred has nan values')
if torch.isnan(targets).any():
raise ValueError('targets has nan values')
if self.use_critic:
# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, targets], dim=0)
# stacked = (stacked / 2 + 0.5).clamp(0, 1)
stacked = stacked.clamp(0, 1)
self.vgg_19(stacked)
# make sure we dont have nans
if torch.isnan(self.vgg19_pool_4.tensor).any():
raise ValueError('vgg19_pool_4 has nan values')
if is_critic_only_step:
critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
critic_losses.append(critic_d_loss)
# don't do generator step
continue
else:
critic_d_loss = 0.0
# doing a regular step
if len(critic_losses) == 0:
critic_d_loss = 0
else:
critic_d_loss = sum(critic_losses) / len(critic_losses)
style_loss = self.get_style_loss() * self.style_weight
content_loss = self.get_content_loss() * self.content_weight
mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
@@ -483,10 +543,13 @@ class TrainESRGANProcess(BaseTrainProcess):
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
# make sure non nan
if torch.isnan(loss):
raise ValueError('loss is nan')
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
scheduler.step()
@@ -549,7 +612,7 @@ class TrainESRGANProcess(BaseTrainProcess):
if self.sample_every and self.step_num % self.sample_every == 0:
# print above the progress bar
self.print(f"Sampling at step {self.step_num}")
self.sample(self.step_num)
self.sample(self.step_num, batch=[targets, inputs])
if self.save_every and self.step_num % self.save_every == 0:
# print above the progress bar

View File

@@ -154,28 +154,28 @@ class Critic:
# train critic here
self.model.train()
self.model.requires_grad_(True)
self.optimizer.zero_grad()
critic_losses = []
for i in range(self.num_critic_per_gen):
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
stacked_output = self.model(inputs)
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
stacked_output = self.model(inputs).float()
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
self.optimizer.zero_grad()
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# avg loss
loss = np.mean(critic_losses)
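In the hunk above, the earlier version appears to call self.optimizer.zero_grad() between backward() and step(), which discards the gradients that were just computed; the revised version zeroes before the forward pass and adds gradient clipping. A minimal sketch of the intended per-iteration ordering, assuming the get_gradient_penalty helper shown later in this commit (the function name and arguments below are otherwise illustrative):

import torch

def critic_iteration(critic, optimizer, scheduler, vgg_output, lambda_gp, device, dtype):
    # One WGAN-GP critic update: zero grads first, step last, never zero in between.
    inputs = vgg_output.detach().to(device, dtype=dtype)
    vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)

    optimizer.zero_grad()
    stacked = critic(inputs).float()
    out_pred, out_target = torch.chunk(stacked, 2, dim=0)

    gradient_penalty = get_gradient_penalty(critic, vgg_target, vgg_pred, device)
    critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + lambda_gp * gradient_penalty

    critic_loss.backward()  # gradients must survive until step()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    return critic_loss.item()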

run.py
View File

@@ -5,6 +5,12 @@ from typing import Union, OrderedDict
sys.path.insert(0, os.getcwd())
# must come before ANY torch or fastai imports
# import toolkit.cuda_malloc
# check if we have DEBUG_TOOLKIT in env
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
# set torch to trace mode
import torch
torch.autograd.set_detect_anomaly(True)
import argparse
from toolkit.job import get_job

View File

@@ -3,6 +3,8 @@ import time
from typing import List, Optional, Literal, Union
import random
from toolkit.prompt_utils import PromptEmbeds
ImgExt = Literal['jpg', 'png', 'webp']
@@ -447,3 +449,11 @@ class GenerateImageConfig:
self.network_multiplier = float(content)
elif flag == 'gr':
self.guidance_rescale = float(content)
def post_process_embeddings(
self,
conditional_prompt_embeds: PromptEmbeds,
unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
):
# this is called after prompt embeds are encoded. We can override them in the future here
pass
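As a purely hypothetical illustration of what this hook makes possible (the subclass and blend weight below are invented for the example, not part of the commit), a config could override post_process_embeddings to rewrite the encoded prompts before they reach the pipeline:

import torch

class BlendedGenerateImageConfig(GenerateImageConfig):
    # hypothetical subclass: nudge the conditional embedding toward the unconditional one
    blend_weight = 0.1

    def post_process_embeddings(self, conditional_prompt_embeds, unconditional_prompt_embeds=None):
        if unconditional_prompt_embeds is None:
            return
        w = self.blend_weight
        conditional_prompt_embeds.text_embeds = torch.lerp(
            conditional_prompt_embeds.text_embeds,
            unconditional_prompt_embeds.text_embeds,
            w,
        )
        if conditional_prompt_embeds.pooled_embeds is not None:
            conditional_prompt_embeds.pooled_embeds = torch.lerp(
                conditional_prompt_embeds.pooled_embeds,
                unconditional_prompt_embeds.pooled_embeds,
                w,
            )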

View File

@@ -53,6 +53,8 @@ class ImageDataset(Dataset, CaptionMixin):
else:
bad_count += 1
self.file_list = new_file_list
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.path}"
@@ -90,7 +92,10 @@ class ImageDataset(Dataset, CaptionMixin):
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
scaler = scale_size / min_img_size
scale_width = int((img.width + 5) * scaler)
scale_height = int((img.height + 5) * scaler)
img = img.resize((scale_width, scale_height), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)

View File

@@ -7,7 +7,7 @@ import itertools
class LosslessLatentDecoder(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
super(LosslessLatentDecoder, self).__init__()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels // (latent_depth * latent_depth))
@@ -46,7 +46,7 @@ class LosslessLatentDecoder(nn.Module):
class LosslessLatentEncoder(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
super(LosslessLatentEncoder, self).__init__()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels * (latent_depth * latent_depth))
@@ -108,7 +108,7 @@ if __name__ == '__main__':
from PIL import Image
import torchvision.transforms as transforms
user_path = os.path.expanduser('~')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
input_path = os.path.join(user_path, "Pictures/sample_2_512.png")

View File

@@ -27,11 +27,17 @@ class ComparativeTotalVariation(torch.nn.Module):
# Gradient penalty
def get_gradient_penalty(critic, real, fake, device):
with torch.autocast(device_type='cuda'):
alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
real = real.float()
fake = fake.float()
alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
if torch.isnan(interpolates).any():
print('d_interpolates is nan')
d_interpolates = critic(interpolates)
fake = torch.ones(real.size(0), 1, device=device)
if torch.isnan(d_interpolates).any():
print('fake is nan')
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
@@ -41,10 +47,14 @@ def get_gradient_penalty(critic, real, fake, device):
only_inputs=True,
)[0]
# see if any are nan
if torch.isnan(gradients).any():
print('gradients is nan')
gradients = gradients.view(gradients.size(0), -1)
gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
return gradient_penalty
return gradient_penalty.float()
class PatternLoss(torch.nn.Module):

View File

@@ -44,6 +44,12 @@ class PromptEmbeds:
self.pooled_embeds = self.pooled_embeds.detach()
return self
def clone(self):
if self.pooled_embeds is not None:
return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
else:
return PromptEmbeds(self.text_embeds.clone())
class EncodedPromptPair:
def __init__(

View File

@@ -368,6 +368,19 @@ class StableDiffusion:
torch.manual_seed(gen_config.seed)
torch.cuda.manual_seed(gen_config.seed)
# encode the prompt ourselves so we can do fun stuff with embeddings
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
)
# allow any manipulations to take place to embeddings
gen_config.post_process_embeddings(
conditional_embeds,
unconditional_embeds,
)
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.is_xl:
# fix guidance rescale for sdxl
@@ -382,10 +395,14 @@ class StableDiffusion:
extra['use_karras_sigmas'] = True
img = pipeline(
prompt=gen_config.prompt,
prompt_2=gen_config.prompt_2,
negative_prompt=gen_config.negative_prompt,
negative_prompt_2=gen_config.negative_prompt_2,
# prompt=gen_config.prompt,
# prompt_2=gen_config.prompt_2,
prompt_embeds=conditional_embeds.text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
# negative_prompt=gen_config.negative_prompt,
# negative_prompt_2=gen_config.negative_prompt_2,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
@@ -395,8 +412,10 @@ class StableDiffusion:
).images[0]
else:
img = pipeline(
prompt=gen_config.prompt,
negative_prompt=gen_config.negative_prompt,
# prompt=gen_config.prompt,
prompt_embeds=conditional_embeds.text_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
# negative_prompt=gen_config.negative_prompt,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
@@ -625,21 +644,25 @@ class StableDiffusion:
# return latents_steps
return latents
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt
# if it is not a list, make it one
if not isinstance(prompt, list):
prompt = [prompt]
if prompt2 is not None and not isinstance(prompt2, list):
prompt2 = [prompt2]
if self.is_xl:
return PromptEmbeds(
train_tools.encode_prompts_xl(
self.tokenizer,
self.text_encoder,
prompt,
prompt2,
num_images_per_prompt=num_images_per_prompt,
use_text_encoder_1=self.use_text_encoder_1,
use_text_encoder_2=self.use_text_encoder_2,
use_text_encoder_1=self.use_text_encoder_1 or force_all,
use_text_encoder_2=self.use_text_encoder_2 or force_all,
)
)
else:

View File

@@ -33,12 +33,17 @@ class ContentLoss(nn.Module):
# Define the separate loss function
def separated_loss(y_pred, y_true):
y_pred = y_pred.float()
y_true = y_true.float()
diff = torch.abs(y_pred - y_true)
l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
return 2. * l2 / content_size
# Calculate itemized loss
pred_itemized_loss = separated_loss(pred_layer, target_layer)
# check if is nan
if torch.isnan(pred_itemized_loss).any():
print('pred_itemized_loss is nan')
# Calculate the mean of itemized loss
loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
@@ -48,6 +53,7 @@ class ContentLoss(nn.Module):
def convert_to_gram_matrix(inputs):
inputs = inputs.float()
shape = inputs.size()
batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
size = height * width * filters
@@ -93,11 +99,14 @@ class StyleLoss(nn.Module):
target_grams = convert_to_gram_matrix(style_target)
pred_grams = convert_to_gram_matrix(preds)
itemized_loss = separated_loss(pred_grams, target_grams)
# check if is nan
if torch.isnan(itemized_loss).any():
print('itemized_loss is nan')
# reshape itemized loss to be (batch, 1, 1, 1)
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
self.loss = loss.to(input_dtype)
self.loss = loss.to(input_dtype).float()
return stacked_input.to(input_dtype)
@@ -149,7 +158,7 @@ def get_style_model_and_losses(
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv4_2']
content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
# set all weights in the model to our dtype
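For reference, the gram-matrix step that convert_to_gram_matrix performs in style losses of this kind, written out as a generic sketch (the standard formulation, not necessarily the repo's exact normalization):

import torch

def gram_matrix(features: torch.Tensor) -> torch.Tensor:
    # features: (batch, filters, height, width) VGG activations
    b, c, h, w = features.shape
    flat = features.reshape(b, c, h * w)            # (batch, filters, pixels)
    gram = torch.bmm(flat, flat.transpose(1, 2))    # (batch, filters, filters)
    return gram / (c * h * w)                       # normalize by size = height * width * filters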

View File

@@ -479,6 +479,7 @@ def encode_prompts_xl(
tokenizers: list['CLIPTokenizer'],
text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
prompts: list[str],
prompts2: Union[list[str], None],
num_images_per_prompt: int = 1,
use_text_encoder_1: bool = True, # sdxl
use_text_encoder_2: bool = True # sdxl
@@ -486,11 +487,13 @@ def encode_prompts_xl(
# text_encoder and text_encoder_2's penuultimate layer's output
text_embeds_list = []
pooled_text_embeds = None # always text_encoder_2's pool
if prompts2 is None:
prompts2 = prompts
for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
# todo, we are using a blank string to ignore that encoder for now.
# find a better way to do this (zeroing?, removing it from the unet?)
prompt_list_to_use = prompts
prompt_list_to_use = prompts if idx == 0 else prompts2
if idx == 0 and not use_text_encoder_1:
prompt_list_to_use = ["" for _ in prompts]
if idx == 1 and not use_text_encoder_2:
@@ -515,6 +518,7 @@ def text_encode(text_encoder: 'CLIPTextModel', tokens):
return text_encoder(tokens.to(text_encoder.device))[0]
def encode_prompts(
tokenizer: 'CLIPTokenizer',
text_encoder: 'CLIPTokenizer',