Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so I can manipulate it.
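At a high level, image generation now encodes the prompts outside the diffusers pipeline and passes precomputed embeddings in, so they can be edited before sampling. A rough sketch of the new flow, pieced together from this diff (sd is an assumed StableDiffusion instance and the pipeline wiring is simplified):

cond = sd.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
uncond = sd.encode_prompt(gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True)

# hook added in this commit; a no-op by default, but it can be overridden to edit the embeddings
gen_config.post_process_embeddings(cond, uncond)

img = pipeline(
    prompt_embeds=cond.text_embeds,
    pooled_prompt_embeds=cond.pooled_embeds,              # SDXL branch only
    negative_prompt_embeds=uncond.text_embeds,
    negative_pooled_prompt_embeds=uncond.pooled_embeds,   # SDXL branch only
    height=gen_config.height,
    width=gen_config.width,
    num_inference_steps=gen_config.num_inference_steps,
).images[0]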
@@ -3,10 +3,12 @@ import glob
import os
import time
from collections import OrderedDict
from typing import List, Optional

from PIL import Image
from PIL.ImageOps import exif_transpose
# from basicsr.archs.rrdbnet_arch import RRDBNet

from toolkit.basic import flush
from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
@@ -67,9 +69,10 @@ class TrainESRGANProcess(BaseTrainProcess):
        self.augmentations = self.get_conf('augmentations', {})
        self.torch_dtype = get_torch_dtype(self.dtype)
        if self.torch_dtype == torch.bfloat16:
-           self.esrgan_dtype = torch.float16
+           self.esrgan_dtype = torch.float32
        else:
            self.esrgan_dtype = torch.float32

        self.vgg_19 = None
        self.style_weight_scalers = []
        self.content_weight_scalers = []
@@ -232,6 +235,7 @@ class TrainESRGANProcess(BaseTrainProcess):
                pattern_size=self.zoom,
                dtype=self.torch_dtype
            ).to(self.device, dtype=self.torch_dtype)
+       self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype)
        loss = torch.mean(self._pattern_loss(pred, target))
        return loss

@@ -269,13 +273,52 @@ class TrainESRGANProcess(BaseTrainProcess):
        if self.use_critic:
            self.critic.save(step)

-   def sample(self, step=None):
+   def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
        sample_folder = os.path.join(self.save_root, 'samples')
        if not os.path.exists(sample_folder):
            os.makedirs(sample_folder, exist_ok=True)
        batch_sample_folder = os.path.join(self.save_root, 'samples_batch')

        batch_targets = None
        batch_inputs = None
        if batch is not None and not os.path.exists(batch_sample_folder):
            os.makedirs(batch_sample_folder, exist_ok=True)

        self.model.eval()

        def process_and_save(img, target_img, save_path):
            output = self.model(img.to(self.device, dtype=self.esrgan_dtype))
            # output = (output / 2 + 0.5).clamp(0, 1)
            output = output.clamp(0, 1)
            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
            output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()

            # convert to pillow image
            output = Image.fromarray((output * 255).astype(np.uint8))

            if isinstance(target_img, torch.Tensor):
                # convert to pil
                target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
                target_img = Image.fromarray((target_img * 255).astype(np.uint8))

            # upscale to size * self.upscale_sample while maintaining pixels
            output = output.resize(
                (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
                resample=Image.NEAREST
            )

            width, height = output.size

            # stack input image and decoded image
            target_image = target_img.resize((width, height))
            output = output.resize((width, height))

            output_img = Image.new('RGB', (width * 2, height))
            output_img.paste(target_image, (0, 0))
            output_img.paste(output, (width, 0))

            output_img.save(save_path)

        with torch.no_grad():
            for i, img_url in enumerate(self.sample_sources):
                img = exif_transpose(Image.open(img_url))
@@ -295,30 +338,6 @@ class TrainESRGANProcess(BaseTrainProcess):

                img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
-               img = img
-               output = self.model(img)
-               # output = (output / 2 + 0.5).clamp(0, 1)
-               output = output.clamp(0, 1)
-               # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-               output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
-
-               # convert to pillow image
-               output = Image.fromarray((output * 255).astype(np.uint8))
-
-               # upscale to size * self.upscale_sample while maintaining pixels
-               output = output.resize(
-                   (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
-                   resample=Image.NEAREST
-               )
-
-               width, height = output.size
-
-               # stack input image and decoded image
-               target_image = target_image.resize((width, height))
-               output = output.resize((width, height))
-
-               output_img = Image.new('RGB', (width * 2, height))
-               output_img.paste(target_image, (0, 0))
-               output_img.paste(output, (width, 0))

                step_num = ''
                if step is not None:
@@ -328,7 +347,23 @@ class TrainESRGANProcess(BaseTrainProcess):
                # zero-pad 2 digits
                i_str = str(i).zfill(2)
                filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
-               output_img.save(os.path.join(sample_folder, filename))
+               process_and_save(img, target_image, os.path.join(sample_folder, filename))
+
+           if batch is not None:
+               batch_targets = batch[0].detach()
+               batch_inputs = batch[1].detach()
+               batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0)
+               batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0)
+
+               for i in range(len(batch_inputs)):
+                   if step is not None:
+                       # zero-pad 9 digits
+                       step_num = f"_{str(step).zfill(9)}"
+                   seconds_since_epoch = int(time.time())
+                   # zero-pad 2 digits
+                   i_str = str(i).zfill(2)
+                   filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
+                   process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))

        self.model.train()

@@ -445,35 +480,60 @@ class TrainESRGANProcess(BaseTrainProcess):
            print("Generating baseline samples")
            self.sample(step=0)
        # range start at self.epoch_num go to self.epochs
+       critic_losses = []
        for epoch in range(self.epoch_num, self.epochs, 1):
            if self.step_num >= self.max_steps:
                break
            flush()
            for targets, inputs in self.data_loader:
                if self.step_num >= self.max_steps:
                    break
                with torch.no_grad():
-                   targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
-                   inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
+                   is_critic_only_step = False
+                   if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform():
+                       is_critic_only_step = True

-               pred = self.model(inputs)
+                   targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
+                   inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()

-               pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
-               targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
                optimizer.zero_grad()
+               # dont do grads here for critic step
+               do_grad = not is_critic_only_step
+               with torch.set_grad_enabled(do_grad):
+                   pred = self.model(inputs)

-               # Run through VGG19
-               if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
-                   stacked = torch.cat([pred, targets], dim=0)
-                   # stacked = (stacked / 2 + 0.5).clamp(0, 1)
-                   stacked = stacked.clamp(0, 1)
-                   self.vgg_19(stacked)
+                   pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
+                   targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
+                   if torch.isnan(pred).any():
+                       raise ValueError('pred has nan values')
+                   if torch.isnan(targets).any():
+                       raise ValueError('targets has nan values')

-               if self.use_critic:
+                   # Run through VGG19
+                   if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+                       stacked = torch.cat([pred, targets], dim=0)
+                       # stacked = (stacked / 2 + 0.5).clamp(0, 1)
+                       stacked = stacked.clamp(0, 1)
+                       self.vgg_19(stacked)
+                   # make sure we dont have nans
+                   if torch.isnan(self.vgg19_pool_4.tensor).any():
+                       raise ValueError('vgg19_pool_4 has nan values')

+                   if is_critic_only_step:
                        critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+                       critic_losses.append(critic_d_loss)
+                       # don't do generator step
+                       continue
-               else:
-                   critic_d_loss = 0.0
+                   # doing a regular step
+                   if len(critic_losses) == 0:
+                       critic_d_loss = 0
+                   else:
+                       critic_d_loss = sum(critic_losses) / len(critic_losses)

                    style_loss = self.get_style_loss() * self.style_weight
                    content_loss = self.get_content_loss() * self.content_weight

                    mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
                    tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
                    pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
@@ -483,10 +543,13 @@ class TrainESRGANProcess(BaseTrainProcess):
                    critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

                    loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
+                   # make sure non nan
+                   if torch.isnan(loss):
+                       raise ValueError('loss is nan')

                    # Backward pass and optimization
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()

@@ -549,7 +612,7 @@ class TrainESRGANProcess(BaseTrainProcess):
                if self.sample_every and self.step_num % self.sample_every == 0:
                    # print above the progress bar
                    self.print(f"Sampling at step {self.step_num}")
-                   self.sample(self.step_num)
+                   self.sample(self.step_num, batch=[targets, inputs])

                if self.save_every and self.step_num % self.save_every == 0:
                    # print above the progress bar

@@ -154,28 +154,28 @@ class Critic:
        # train critic here
        self.model.train()
        self.model.requires_grad_(True)
        self.optimizer.zero_grad()

        critic_losses = []
        for i in range(self.num_critic_per_gen):
-           inputs = vgg_output.detach()
-           inputs = inputs.to(self.device, dtype=self.torch_dtype)
-           self.optimizer.zero_grad()
+           inputs = vgg_output.detach()
+           inputs = inputs.to(self.device, dtype=self.torch_dtype)
+           self.optimizer.zero_grad()

-           vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+           vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)

-           stacked_output = self.model(inputs)
-           out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+           stacked_output = self.model(inputs).float()
+           out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)

-           # Compute gradient penalty
-           gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+           # Compute gradient penalty
+           gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)

-           # Compute WGAN-GP critic loss
-           critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
-           critic_loss.backward()
-           self.optimizer.zero_grad()
-           self.optimizer.step()
-           self.scheduler.step()
-           critic_losses.append(critic_loss.item())
+           # Compute WGAN-GP critic loss
+           critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+           critic_loss.backward()
+           torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+           self.optimizer.step()
+           self.scheduler.step()
+           critic_losses.append(critic_loss.item())

        # avg loss
        loss = np.mean(critic_losses)
run.py
@@ -5,6 +5,12 @@ from typing import Union, OrderedDict
sys.path.insert(0, os.getcwd())
# must come before ANY torch or fastai imports
# import toolkit.cuda_malloc

+# check if we have DEBUG_TOOLKIT in env
+if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
+    # set torch to trace mode
+    import torch
+    torch.autograd.set_detect_anomaly(True)
+
import argparse
from toolkit.job import get_job

@@ -3,6 +3,8 @@ import time
from typing import List, Optional, Literal, Union
import random

+from toolkit.prompt_utils import PromptEmbeds
+
ImgExt = Literal['jpg', 'png', 'webp']

@@ -447,3 +449,11 @@ class GenerateImageConfig:
            self.network_multiplier = float(content)
        elif flag == 'gr':
            self.guidance_rescale = float(content)
+
+   def post_process_embeddings(
+           self,
+           conditional_prompt_embeds: PromptEmbeds,
+           unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
+   ):
+       # this is called after prompt embeds are encoded. We can override them in the future here
+       pass
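The new post_process_embeddings hook is a no-op by default, so the encoded prompts can be edited before they reach the pipeline. A minimal sketch of how it could be used (the subclass name and the scaling tweak are illustrative assumptions, not part of this commit):

class MyGenerateImageConfig(GenerateImageConfig):
    def post_process_embeddings(
            self,
            conditional_prompt_embeds: PromptEmbeds,
            unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
    ):
        # hypothetical tweak: scale the conditional embeddings in place,
        # since the caller keeps using the same PromptEmbeds objects afterwards
        conditional_prompt_embeds.text_embeds = conditional_prompt_embeds.text_embeds * 1.05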
@@ -53,6 +53,8 @@ class ImageDataset(Dataset, CaptionMixin):
            else:
                bad_count += 1

        self.file_list = new_file_list

        print(f" - Found {len(self.file_list)} images")
        print(f" - Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, f"no images found in {self.path}"
@@ -90,7 +92,10 @@ class ImageDataset(Dataset, CaptionMixin):
                scale_size = self.resolution
            else:
                scale_size = random.randint(self.resolution, int(min_img_size))
-           img = img.resize((scale_size, scale_size), Image.BICUBIC)
+           scaler = scale_size / min_img_size
+           scale_width = int((img.width + 5) * scaler)
+           scale_height = int((img.height + 5) * scaler)
+           img = img.resize((scale_width, scale_height), Image.BICUBIC)
            img = transforms.RandomCrop(self.resolution)(img)
        else:
            img = transforms.CenterCrop(min_img_size)(img)

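The new resize path scales both sides of the image by the same factor (with a small pixel margin) instead of squashing it to a square, so the random crop sees an undistorted image. A rough worked example of the arithmetic, using made-up numbers for the image size and training resolution:

# hypothetical 1024x768 image with resolution = 512
img_width, img_height = 1024, 768
min_img_size = min(img_width, img_height)        # 768
scale_size = 600                                 # pretend random.randint(512, 768) returned 600
scaler = scale_size / min_img_size               # 0.78125
scale_width = int((img_width + 5) * scaler)      # 803
scale_height = int((img_height + 5) * scaler)    # 603, so the short side still covers the 512 crop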
@@ -7,7 +7,7 @@ import itertools
class LosslessLatentDecoder(nn.Module):
    def __init__(self, in_channels, latent_depth, dtype=torch.float32):
        super(LosslessLatentDecoder, self).__init__()
-       device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+       device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.latent_depth = latent_depth
        self.in_channels = in_channels
        self.out_channels = int(in_channels // (latent_depth * latent_depth))
@@ -46,7 +46,7 @@ class LosslessLatentDecoder(nn.Module):
class LosslessLatentEncoder(nn.Module):
    def __init__(self, in_channels, latent_depth, dtype=torch.float32):
        super(LosslessLatentEncoder, self).__init__()
-       device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+       device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.latent_depth = latent_depth
        self.in_channels = in_channels
        self.out_channels = int(in_channels * (latent_depth * latent_depth))
@@ -108,7 +108,7 @@ if __name__ == '__main__':
    from PIL import Image
    import torchvision.transforms as transforms
    user_path = os.path.expanduser('~')
-   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    input_path = os.path.join(user_path, "Pictures/sample_2_512.png")

@@ -27,11 +27,17 @@ class ComparativeTotalVariation(torch.nn.Module):
# Gradient penalty
def get_gradient_penalty(critic, real, fake, device):
    with torch.autocast(device_type='cuda'):
-       alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
+       real = real.float()
+       fake = fake.float()
+       alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
        interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
+       if torch.isnan(interpolates).any():
+           print('d_interpolates is nan')
        d_interpolates = critic(interpolates)
        fake = torch.ones(real.size(0), 1, device=device)


+       if torch.isnan(d_interpolates).any():
+           print('fake is nan')
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
@@ -41,10 +47,14 @@ def get_gradient_penalty(critic, real, fake, device):
            only_inputs=True,
        )[0]

+       # see if any are nan
+       if torch.isnan(gradients).any():
+           print('gradients is nan')

        gradients = gradients.view(gradients.size(0), -1)
        gradient_norm = gradients.norm(2, dim=1)
        gradient_penalty = ((gradient_norm - 1) ** 2).mean()
-       return gradient_penalty
+       return gradient_penalty.float()


class PatternLoss(torch.nn.Module):
@@ -44,6 +44,12 @@ class PromptEmbeds:
            self.pooled_embeds = self.pooled_embeds.detach()
        return self

+   def clone(self):
+       if self.pooled_embeds is not None:
+           return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
+       else:
+           return PromptEmbeds(self.text_embeds.clone())
+

class EncodedPromptPair:
    def __init__(
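Since post_process_embeddings receives the live PromptEmbeds objects, clone() gives a safe copy to experiment with before overwriting anything. A small illustrative sketch (the prompts and the blend factor are made up for the example, and sd is again an assumed StableDiffusion instance):

base = sd.encode_prompt("a photo of a cat", force_all=True)
other = sd.encode_prompt("a photo of a dog", force_all=True)

# blend the two prompts in embedding space without touching the originals
mixed = base.clone()
mixed.text_embeds = 0.7 * base.text_embeds + 0.3 * other.text_embeds
if mixed.pooled_embeds is not None:
    mixed.pooled_embeds = 0.7 * base.pooled_embeds + 0.3 * other.pooled_embeds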
@@ -368,6 +368,19 @@ class StableDiffusion:
        torch.manual_seed(gen_config.seed)
        torch.cuda.manual_seed(gen_config.seed)

+       # encode the prompt ourselves so we can do fun stuff with embeddings
+       conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
+
+       unconditional_embeds = self.encode_prompt(
+           gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
+       )
+
+       # allow any manipulations to take place to embeddings
+       gen_config.post_process_embeddings(
+           conditional_embeds,
+           unconditional_embeds,
+       )
+
        # todo do we disable text encoder here as well if disabled for model, or only do that for training?
        if self.is_xl:
            # fix guidance rescale for sdxl
@@ -382,10 +395,14 @@ class StableDiffusion:
                extra['use_karras_sigmas'] = True

            img = pipeline(
-               prompt=gen_config.prompt,
-               prompt_2=gen_config.prompt_2,
-               negative_prompt=gen_config.negative_prompt,
-               negative_prompt_2=gen_config.negative_prompt_2,
+               # prompt=gen_config.prompt,
+               # prompt_2=gen_config.prompt_2,
+               prompt_embeds=conditional_embeds.text_embeds,
+               pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+               negative_prompt_embeds=unconditional_embeds.text_embeds,
+               negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+               # negative_prompt=gen_config.negative_prompt,
+               # negative_prompt_2=gen_config.negative_prompt_2,
                height=gen_config.height,
                width=gen_config.width,
                num_inference_steps=gen_config.num_inference_steps,
@@ -395,8 +412,10 @@ class StableDiffusion:
            ).images[0]
        else:
            img = pipeline(
-               prompt=gen_config.prompt,
-               negative_prompt=gen_config.negative_prompt,
+               # prompt=gen_config.prompt,
+               prompt_embeds=conditional_embeds.text_embeds,
+               negative_prompt_embeds=unconditional_embeds.text_embeds,
+               # negative_prompt=gen_config.negative_prompt,
                height=gen_config.height,
                width=gen_config.width,
                num_inference_steps=gen_config.num_inference_steps,
@@ -625,21 +644,25 @@ class StableDiffusion:
        # return latents_steps
        return latents

-   def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
+   def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
        # sd1.5 embeddings are (bs, 77, 768)
        prompt = prompt
        # if it is not a list, make it one
        if not isinstance(prompt, list):
            prompt = [prompt]

+       if prompt2 is not None and not isinstance(prompt2, list):
+           prompt2 = [prompt2]
        if self.is_xl:
            return PromptEmbeds(
                train_tools.encode_prompts_xl(
                    self.tokenizer,
                    self.text_encoder,
                    prompt,
+                   prompt2,
                    num_images_per_prompt=num_images_per_prompt,
-                   use_text_encoder_1=self.use_text_encoder_1,
-                   use_text_encoder_2=self.use_text_encoder_2,
+                   use_text_encoder_1=self.use_text_encoder_1 or force_all,
+                   use_text_encoder_2=self.use_text_encoder_2 or force_all,
                )
            )
        else:

@@ -33,12 +33,17 @@ class ContentLoss(nn.Module):

        # Define the separate loss function
        def separated_loss(y_pred, y_true):
+           y_pred = y_pred.float()
+           y_true = y_true.float()
            diff = torch.abs(y_pred - y_true)
            l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
            return 2. * l2 / content_size

        # Calculate itemized loss
        pred_itemized_loss = separated_loss(pred_layer, target_layer)
+       # check if is nan
+       if torch.isnan(pred_itemized_loss).any():
+           print('pred_itemized_loss is nan')

        # Calculate the mean of itemized loss
        loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
@@ -48,6 +53,7 @@ class ContentLoss(nn.Module):


def convert_to_gram_matrix(inputs):
+   inputs = inputs.float()
    shape = inputs.size()
    batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
    size = height * width * filters
@@ -93,11 +99,14 @@ class StyleLoss(nn.Module):
        target_grams = convert_to_gram_matrix(style_target)
        pred_grams = convert_to_gram_matrix(preds)
        itemized_loss = separated_loss(pred_grams, target_grams)
+       # check if is nan
+       if torch.isnan(itemized_loss).any():
+           print('itemized_loss is nan')
        # reshape itemized loss to be (batch, 1, 1, 1)
        itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
        # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
        loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
-       self.loss = loss.to(input_dtype)
+       self.loss = loss.to(input_dtype).float()
        return stacked_input.to(input_dtype)

@@ -149,7 +158,7 @@ def get_style_model_and_losses(
):
    # content_layers = ['conv_4']
    # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
-   content_layers = ['conv4_2']
+   content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
    style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
    cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
    # set all weights in the model to our dtype

@@ -479,6 +479,7 @@ def encode_prompts_xl(
    tokenizers: list['CLIPTokenizer'],
    text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
    prompts: list[str],
+   prompts2: Union[list[str], None],
    num_images_per_prompt: int = 1,
    use_text_encoder_1: bool = True, # sdxl
    use_text_encoder_2: bool = True # sdxl
@@ -486,11 +487,13 @@ def encode_prompts_xl(
    # text_encoder and text_encoder_2's penuultimate layer's output
    text_embeds_list = []
    pooled_text_embeds = None # always text_encoder_2's pool
+   if prompts2 is None:
+       prompts2 = prompts

    for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
        # todo, we are using a blank string to ignore that encoder for now.
        # find a better way to do this (zeroing?, removing it from the unet?)
-       prompt_list_to_use = prompts
+       prompt_list_to_use = prompts if idx == 0 else prompts2
        if idx == 0 and not use_text_encoder_1:
            prompt_list_to_use = ["" for _ in prompts]
        if idx == 1 and not use_text_encoder_2:
@@ -515,6 +518,7 @@ def text_encode(text_encoder: 'CLIPTextModel', tokens):
    return text_encoder(tokens.to(text_encoder.device))[0]


def encode_prompts(
        tokenizer: 'CLIPTokenizer',
        text_encoder: 'CLIPTokenizer',