Bug fixes and little improvements here and there.

commit 3f3636b788 (parent 833c833f28)
Author: Jaret Burkett
Date: 2024-06-08 06:24:20 -06:00

12 changed files with 358 additions and 117 deletions

View File

@@ -4,7 +4,7 @@ from collections import OrderedDict
from typing import Union, Literal, List, Optional
import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny
from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel
import torch.functional as F
from safetensors.torch import load_file
@@ -824,6 +824,10 @@ class SDTrainer(BaseSDTrainProcess):
# remove the residuals as we won't use them on prediction when matching control
if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
del pred_kwargs['down_intrablock_additional_residuals']
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
del pred_kwargs['mid_block_additional_residual']
if can_disable_adapter:
self.adapter.is_active = was_adapter_active
@@ -1065,7 +1069,7 @@ class SDTrainer(BaseSDTrainProcess):
# if prompt_2 is not None:
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with network:
with (network):
# encode clip adapter here so embeds are active for tokenizer
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('encode_clip_vision_embeds'):
@@ -1162,26 +1166,27 @@ class SDTrainer(BaseSDTrainProcess):
# flush()
pred_kwargs = {}
if has_adapter_img and (
(self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
@@ -1362,6 +1367,32 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_cfg:
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if self.train_config.do_cfg:
raise ValueError("ControlNetModel is not supported with CFG")
with torch.set_grad_enabled(self.adapter is not None):
adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
# add_text_embeds is pooled_prompt_embeds for sdxl
added_cond_kwargs = {}
if self.sd.is_xl:
added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
down_block_res_samples, mid_block_res_sample = adapter(
noisy_latents,
timesteps,
encoder_hidden_states=conditional_embeds.text_embeds,
controlnet_cond=adapter_images,
conditioning_scale=1.0,
guess_mode=False,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the guidance later
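For context, a minimal sketch (not part of this diff) of how residuals like the ones built above are normally handed to a diffusers UNet2DConditionModel forward call; unet, noisy_latents, timesteps, and prompt_embeds are placeholder names assumed to exist:

# Illustrative only: consume ControlNet residuals in the UNet forward pass
noise_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=prompt_embeds,
    # outputs of ControlNetModel(...) as produced in the hunk above
    down_block_additional_residuals=[r.to(unet.dtype) for r in down_block_res_samples],
    mid_block_additional_residual=mid_block_res_sample.to(unet.dtype),
    return_dict=False,
)[0]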
@@ -1423,10 +1454,10 @@ class SDTrainer(BaseSDTrainProcess):
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
if self.is_bfloat:
loss.backward()
else:
self.scaler.scale(loss).backward()
# if self.is_bfloat:
loss.backward()
# else:
# self.scaler.scale(loss).backward()
# flush()
if not self.is_grad_accumulation_step:
@@ -1443,8 +1474,8 @@ class SDTrainer(BaseSDTrainProcess):
self.optimizer.step()
else:
# apply gradients
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.step()
# self.scaler.update()
# self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
else:
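For reference, a minimal sketch of the standard torch.cuda.amp pattern that the scaler-vs-bfloat16 branch above switches between; the tiny model and synthetic loss are illustrative only:

import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_dtype = torch.bfloat16                      # or torch.float16
scaler = torch.cuda.amp.GradScaler()              # only actually needed for fp16

with torch.autocast("cuda", dtype=train_dtype):
    loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()

if train_dtype == torch.float16:
    scaler.scale(loss).backward()   # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales grads, skips the step on inf/nan
    scaler.update()
else:
    loss.backward()                 # bf16/fp32: plain backward, no scaler needed
    optimizer.step()
optimizer.zero_grad(set_to_none=True)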

View File

@@ -10,7 +10,7 @@ from typing import Union, List, Optional
import numpy as np
import yaml
from diffusers import T2IAdapter
from diffusers import T2IAdapter, ControlNetModel
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
@@ -143,7 +143,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None
self.embedding: Union[Embedding, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -368,6 +368,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
pass
def save(self, step=None):
flush()
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -423,6 +424,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# add _lora to name
if self.adapter_config.type == 't2i':
adapter_name += '_t2i'
elif self.adapter_config.type == 'control_net':
adapter_name += '_cn'
elif self.adapter_config.type == 'clip':
adapter_name += '_clip'
elif self.adapter_config.type.startswith('ip'):
@@ -441,6 +444,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
elif self.adapter_config.type == 'control_net':
# save in diffusers format
name_or_path = file_path.replace('.safetensors', '')
# move it to the new dtype and cpu
orig_device = self.adapter.device
orig_dtype = self.adapter.dtype
self.adapter = self.adapter.to(torch.device('cpu'), dtype=get_torch_dtype(self.save_config.dtype))
self.adapter.save_pretrained(
name_or_path,
dtype=get_torch_dtype(self.save_config.dtype),
safe_serialization=True
)
meta_path = os.path.join(name_or_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(self.meta, f)
# move it back
self.adapter = self.adapter.to(orig_device, dtype=orig_dtype)
else:
save_ip_adapter_from_diffusers(
state_dict,
@@ -551,6 +571,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
paths = [p for p in paths if '_refiner' not in p]
if '_t2i' not in name:
paths = [p for p in paths if '_t2i' not in p]
if '_cn' not in name:
paths = [p for p in paths if '_cn' not in p]
if len(paths) > 0:
latest_path = max(paths, key=os.path.getctime)
@@ -956,8 +978,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
def setup_adapter(self):
# t2i adapter
is_t2i = self.adapter_config.type == 't2i'
is_control_net = self.adapter_config.type == 'control_net'
if self.adapter_config.type == 't2i':
suffix = 't2i'
elif self.adapter_config.type == 'control_net':
suffix = 'cn'
elif self.adapter_config.type == 'clip':
suffix = 'clip'
elif self.adapter_config.type == 'reference':
@@ -990,6 +1015,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
elif is_control_net:
if self.adapter_config.name_or_path is None:
raise ValueError("ControlNet requires a name_or_path to load from currently")
load_from_path = self.adapter_config.name_or_path
if latest_save_path is not None:
load_from_path = latest_save_path
self.adapter = ControlNetModel.from_pretrained(
load_from_path,
torch_dtype=get_torch_dtype(self.train_config.dtype),
)
elif self.adapter_config.type == 'clip':
self.adapter = ClipVisionAdapter(
sd=self.sd,
@@ -1013,7 +1048,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
if latest_save_path is not None:
if latest_save_path is not None and not is_control_net:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
if is_t2i:
@@ -1040,8 +1075,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
if latest_save_path is not None and self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
self.sd.adapter = self.adapter
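A small sketch of the diffusers save/load round trip that the 'control_net' branches above rely on; the hub id and output folder are illustrative assumptions, not paths from this repo:

import torch
from diffusers import ControlNetModel

# save_pretrained() writes a diffusers-format folder (config.json + safetensors weights)
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
controlnet.save_pretrained("output/my_controlnet", safe_serialization=True)

# a later run can resume training or sample by loading straight from that folder
controlnet = ControlNetModel.from_pretrained(
    "output/my_controlnet", torch_dtype=torch.float16
)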

View File

@@ -1,6 +1,7 @@
import copy
import glob
import os
import shutil
import time
from collections import OrderedDict
@@ -13,6 +14,7 @@ from torch import nn
from torchvision.transforms import transforms
from jobs.process import BaseTrainProcess
from toolkit.image_utils import show_tensors
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
@@ -25,6 +27,8 @@ from tqdm import tqdm
import time
import numpy as np
from .models.vgg19_critic import Critic
from torchvision.transforms import Resize
import lpips
IMAGE_TRANSFORMS = transforms.Compose(
[
@@ -62,6 +66,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
self.optimizer_params = self.get_conf('optimizer_params', {})
@@ -71,6 +76,9 @@ class TrainVAEProcess(BaseTrainProcess):
self.vgg_19 = None
self.style_weight_scalers = []
self.content_weight_scalers = []
self.lpips_loss:lpips.LPIPS = None
self.vae_scale_factor = 8
self.step_num = 0
self.epoch_num = 0
@@ -137,6 +145,15 @@ class TrainVAEProcess(BaseTrainProcess):
num_workers=6
)
def remove_oldest_checkpoint(self):
max_to_keep = 4
folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
if len(folders) > max_to_keep:
folders.sort(key=os.path.getmtime)
for folder in folders[:-max_to_keep]:
print(f"Removing {folder}")
shutil.rmtree(folder)
def setup_vgg19(self):
if self.vgg_19 is None:
self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
@@ -211,7 +228,7 @@ class TrainVAEProcess(BaseTrainProcess):
def get_pattern_loss(self, pred, target):
if self._pattern_loss is None:
self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device,
self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(self.device,
dtype=self.torch_dtype)
loss = torch.mean(self._pattern_loss(pred, target))
return loss
@@ -226,25 +243,21 @@ class TrainVAEProcess(BaseTrainProcess):
step_num = f"_{str(step).zfill(9)}"
self.update_training_metadata()
filename = f'{self.job.name}{step_num}.safetensors'
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
filename = f'{self.job.name}{step_num}_diffusers'
state_dict = convert_diffusers_back_to_ldm(self.vae)
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(torch.float32)
state_dict[key] = v
# having issues with meta
save_file(state_dict, os.path.join(self.save_root, filename), save_meta)
self.vae = self.vae.to("cpu", dtype=torch.float16)
self.vae.save_pretrained(
save_directory=os.path.join(self.save_root, filename)
)
self.vae = self.vae.to(self.device, dtype=self.torch_dtype)
self.print(f"Saved to {os.path.join(self.save_root, filename)}")
if self.use_critic:
self.critic.save(step)
self.remove_oldest_checkpoint()
def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder):
@@ -280,6 +293,13 @@ class TrainVAEProcess(BaseTrainProcess):
output_img.paste(input_img, (0, 0))
output_img.paste(decoded, (self.resolution, 0))
scale_up = 2
if output_img.height <= 300:
scale_up = 4
# scale up using nearest neighbor
output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)
step_num = ''
if step is not None:
# zero-pad 9 digits
@@ -294,7 +314,7 @@ class TrainVAEProcess(BaseTrainProcess):
path_to_load = self.vae_path
# see if we have a checkpoint in our output to resume from
self.print(f"Looking for latest checkpoint in {self.save_root}")
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
if files and len(files) > 0:
latest_file = max(files, key=os.path.getmtime)
print(f" - Latest checkpoint is: {latest_file}")
@@ -306,13 +326,14 @@ class TrainVAEProcess(BaseTrainProcess):
self.print(f"Loading VAE")
self.print(f" - Loading VAE: {path_to_load}")
if self.vae is None:
self.vae = load_vae(path_to_load, dtype=self.torch_dtype)
self.vae = AutoencoderKL.from_pretrained(path_to_load)
# set decoder to train
self.vae.to(self.device, dtype=self.torch_dtype)
self.vae.requires_grad_(False)
self.vae.eval()
self.vae.decoder.train()
self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
def run(self):
super().run()
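As a quick worked example of the vae_scale_factor set above (values assume the standard Stable Diffusion AutoencoderKL config):

# the stock SD/SDXL AutoencoderKL has 4 down blocks
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # 2**3 = 8
# so batches are resized to multiples of 8 before encoding
assert vae_scale_factor == 8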
@@ -374,6 +395,10 @@ class TrainVAEProcess(BaseTrainProcess):
if self.use_critic:
self.critic.setup()
if self.lpips_weight > 0 and self.lpips_loss is None:
# self.lpips_loss = lpips.LPIPS(net='vgg')
self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)
optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
optimizer_params=self.optimizer_params)
@@ -397,6 +422,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.sample()
blank_losses = OrderedDict({
"total": [],
"lpips": [],
"style": [],
"content": [],
"mse": [],
@@ -415,17 +441,29 @@ class TrainVAEProcess(BaseTrainProcess):
for batch in self.data_loader:
if self.step_num >= self.max_steps:
break
with torch.no_grad():
batch = batch.to(self.device, dtype=self.torch_dtype)
batch = batch.to(self.device, dtype=self.torch_dtype)
# forward pass
dgd = self.vae.encode(batch).latent_dist
mu, logvar = dgd.mean, dgd.logvar
latents = dgd.sample()
latents.requires_grad_(True)
# resize so the dimensions divide evenly by the vae scale factor
if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
# forward pass
dgd = self.vae.encode(batch).latent_dist
mu, logvar = dgd.mean, dgd.logvar
latents = dgd.sample()
latents.detach().requires_grad_(True)
pred = self.vae.decode(latents).sample
with torch.no_grad():
show_tensors(
pred.clamp(-1, 1).clone(),
"combined tensor"
)
# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, batch], dim=0)
@@ -441,14 +479,31 @@ class TrainVAEProcess(BaseTrainProcess):
content_loss = self.get_content_loss() * self.content_weight
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
if self.lpips_weight > 0:
lpips_loss = self.lpips_loss(
pred.clamp(-1, 1),
batch.clamp(-1, 1)
).mean() * self.lpips_weight
else:
lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
if self.use_critic:
critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
# do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
if self.lpips_weight > 0:
max_target = lpips_loss.abs() * 0.1
with torch.no_grad():
crit_g_scaler = 1.0
if critic_gen_loss.abs() > max_target:
crit_g_scaler = max_target / critic_gen_loss.abs()
critic_gen_loss *= crit_g_scaler
else:
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss
# Backward pass and optimization
optimizer.zero_grad()
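A standalone sketch of the lpips API used above; per the library's convention, inputs are expected in [-1, 1] and the forward pass returns a per-image distance:

import torch
import lpips

loss_fn = lpips.LPIPS(net='vgg')            # VGG backbone, matching the trainer
img0 = torch.rand(1, 3, 256, 256) * 2 - 1   # pretend prediction in [-1, 1]
img1 = torch.rand(1, 3, 256, 256) * 2 - 1   # pretend target in [-1, 1]
dist = loss_fn(img0, img1)                  # shape [1, 1, 1, 1]
print(dist.mean().item())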
@@ -460,6 +515,8 @@ class TrainVAEProcess(BaseTrainProcess):
loss_value = loss.item()
# get exponent like 3.54e-4
loss_string = f"loss: {loss_value:.2e}"
if self.lpips_weight > 0:
loss_string += f" lpips: {lpips_loss.item():.2e}"
if self.content_weight > 0:
loss_string += f" cnt: {content_loss.item():.2e}"
if self.style_weight > 0:
@@ -496,6 +553,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.progress_bar.update(1)
epoch_losses["total"].append(loss_value)
epoch_losses["lpips"].append(lpips_loss.item())
epoch_losses["style"].append(style_loss.item())
epoch_losses["content"].append(content_loss.item())
epoch_losses["mse"].append(mse_loss.item())
@@ -506,6 +564,7 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["crD"].append(critic_d_loss)
log_losses["total"].append(loss_value)
log_losses["lpips"].append(lpips_loss.item())
log_losses["style"].append(style_loss.item())
log_losses["content"].append(content_loss.item())
log_losses["mse"].append(mse_loss.item())

View File

@@ -24,4 +24,5 @@ controlnet_aux==0.0.7
python-dotenv
bitsandbytes
xformers
hf_transfer
hf_transfer
lpips

View File

@@ -7,10 +7,11 @@ from torchvision import transforms
import sys
import os
import cv2
import random
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.paths import SD_SCRIPTS_ROOT
import torchvision.transforms.functional
from toolkit.image_utils import show_img
sys.path.append(SD_SCRIPTS_ROOT)
@@ -32,7 +33,7 @@ parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
dataset_folder = args.dataset_folder
resolution = 1024
resolution = 512
bucket_tolerance = 64
batch_size = 1
@@ -41,6 +42,7 @@ batch_size = 1
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
control_path=dataset_folder,
resolution=resolution,
# caption_ext='json',
default_caption='default',
@@ -48,62 +50,135 @@ dataset_config = DatasetConfig(
buckets=True,
bucket_tolerance=bucket_tolerance,
# poi='person',
shuffle_augmentations=True,
# shuffle_augmentations=True,
# augmentations=[
# {
# 'method': 'GaussianBlur',
# 'blur_limit': (1, 16),
# 'sigma_limit': (0, 8),
# 'p': 0.8
# },
# {
# 'method': 'ImageCompression',
# 'quality_lower': 10,
# 'quality_upper': 100,
# 'compression_type': 0,
# 'p': 0.8
# },
# {
# 'method': 'ImageCompression',
# 'quality_lower': 20,
# 'quality_upper': 100,
# 'compression_type': 1,
# 'p': 0.8
# },
# {
# 'method': 'RingingOvershoot',
# 'blur_limit': (3, 35),
# 'cutoff': (0.7, 1.96),
# 'p': 0.8
# },
# {
# 'method': 'GaussNoise',
# 'var_limit': (0, 300),
# 'per_channel': True,
# 'mean': 0.0,
# 'p': 0.8
# },
# {
# 'method': 'GlassBlur',
# 'sigma': 0.6,
# 'max_delta': 7,
# 'iterations': 2,
# 'mode': 'fast',
# 'p': 0.8
# },
# {
# 'method': 'Downscale',
# 'scale_max': 0.5,
# 'interpolation': 'cv2.INTER_CUBIC',
# 'p': 0.8
# 'method': 'Posterize',
# 'num_bits': [(0, 4), (0, 4), (0, 4)],
# 'p': 1.0
# },
#
# ]
)
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
if random.random() < p:
kernel_size = random.randint(min_kernel_size, max_kernel_size)
# make sure it is odd
if kernel_size % 2 == 0:
kernel_size += 1
img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size)
return img
def quantize(image, palette):
"""
Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
Only works for one image i.e. CHW. Does NOT work for batches.
ref https://discuss.pytorch.org/t/color-quantization/104528/4
"""
orig_dtype = image.dtype
C, H, W = image.shape
n_colors = palette.shape[0]
# Easier to work with list of colors
flat_img = image.view(C, -1).T # [C, H, W] -> [H*W, C]
# Repeat image so that there are n_color number of columns of the same image
flat_img_per_color = flat_img.unsqueeze(1).expand(-1, n_colors, -1) # [H*W, C] -> [H*W, n_colors, C]
# Get euclidean distance between each pixel in each column and the column's respective color
# i.e. column 1 lists distance of each pixel to color #1 in palette, column 2 to color #2 etc.
squared_distance = (flat_img_per_color - palette.unsqueeze(0)) ** 2
euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8) # [H*W, n_colors, C] -> [H*W, n_colors]
# Get the shortest distance (one value per row (H*W) is selected)
min_distances, min_indices = torch.min(euclidean_distance, dim=-1) # [H*W, n_colors] -> [H*W]
# Create a mask for the closest colors
one_hot_mask = torch.nn.functional.one_hot(min_indices, num_classes=n_colors).float() # [H*W, n_colors]
# Multiply the mask with the palette colors to get the quantized image
quantized = torch.matmul(one_hot_mask, palette) # [H*W, n_colors] @ [n_colors, C] -> [H*W, C]
# Reshape it back to the original input format.
quantized_img = quantized.T.view(C, H, W) # [H*W, C] -> [C, H, W]
return quantized_img.to(orig_dtype)
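A quick usage example for quantize() (random image and palette, only to show the expected shapes and value range; relies on the script's existing torch import):

# CHW float image in [0, 255] and an [n_colors, C] palette
img = torch.rand(3, 64, 64) * 255
palette = torch.tensor([
    [0.0, 0.0, 0.0],
    [127.0, 127.0, 127.0],
    [255.0, 255.0, 255.0],
])
quantized = quantize(img, palette)   # every pixel snapped to its nearest palette color
assert quantized.shape == img.shape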
def color_block_imgs(img, neg1_1=False):
# expects values 0 - 1
orig_dtype = img.dtype
if neg1_1:
img = img * 0.5 + 0.5
img = img * 255
img = img.clamp(0, 255)
img = img.to(torch.uint8)
img_chunks = torch.chunk(img, img.shape[0], dim=0)
posterized_chunks = []
for chunk in img_chunks:
img_size = (chunk.shape[2] + chunk.shape[3]) // 2
# min kernel size of 1% of image, max 10%
min_kernel_size = int(img_size * 0.01)
max_kernel_size = int(img_size * 0.1)
# blur first
chunk = random_blur(chunk, min_kernel_size=min_kernel_size, max_kernel_size=max_kernel_size, p=0.8)
num_colors = random.randint(1, 16)
resize_to = 16
# chunk = torchvision.transforms.functional.posterize(chunk, num_bits_to_use)
# mean_color = [int(x.item()) for x in torch.mean(chunk.float(), dim=(0, 2, 3))]
# shrink the image down to resize_to x resize_to to sample candidate colors
shrunk = torchvision.transforms.functional.resize(chunk, [resize_to, resize_to])
mean_color = [int(x.item()) for x in torch.mean(shrunk.float(), dim=(0, 2, 3))]
colors = shrunk.view(3, -1).T
# remove duplicates
colors = torch.unique(colors, dim=0)
colors = colors.numpy()
colors = colors.tolist()
use_colors = [random.choice(colors) for _ in range(num_colors)]
pallette = torch.tensor([
[0, 0, 0],
mean_color,
[255, 255, 255],
] + use_colors, dtype=torch.float32)
chunk = quantize(chunk.squeeze(0), pallette).unsqueeze(0)
# chunk = torchvision.transforms.functional.equalize(chunk)
# color jitter
if random.random() < 0.5:
chunk = torchvision.transforms.functional.adjust_contrast(chunk, random.uniform(1.0, 1.5))
if random.random() < 0.5:
chunk = torchvision.transforms.functional.adjust_saturation(chunk, random.uniform(1.0, 2.0))
# if random.random() < 0.5:
# chunk = torchvision.transforms.functional.adjust_brightness(chunk, random.uniform(0.5, 1.5))
chunk = random_blur(chunk, p=0.6)
posterized_chunks.append(chunk)
img = torch.cat(posterized_chunks, dim=0)
img = img.to(orig_dtype)
img = img / 255
if neg1_1:
img = img * 2 - 1
return img
# run through an epoch and check sizes
dataloader_iterator = iter(dataloader)
@@ -112,11 +187,19 @@ for epoch in range(args.epochs):
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
img_batch = color_block_imgs(img_batch, neg1_1=True)
chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them so they are side by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)
control_chunks = torch.chunk(batch.control_tensor, batch_size, dim=0)
big_control_img = torch.cat(control_chunks, dim=3)
big_control_img = big_control_img.squeeze(0) * 2 - 1
big_img = torch.cat([big_img, big_control_img], dim=2)
min_val = big_img.min()
max_val = big_img.max()
@@ -127,7 +210,7 @@ for epoch in range(args.epochs):
show_img(img)
# time.sleep(1.0)
time.sleep(1.0)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)

View File

@@ -129,13 +129,13 @@ class NetworkConfig:
self.conv = 4
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
@@ -530,6 +530,8 @@ class DatasetConfig:
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
self.extra_values: List[float] = kwargs.get('extra_values', [])
self.square_crop: bool = kwargs.get('square_crop', False)
# apply the same augmentations to control images. Usually want this true unless it's a special case
self.replay_transforms: bool = kwargs.get('replay_transforms', True)
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

View File

@@ -860,7 +860,8 @@ class AugmentationFileItemDTOMixin:
# only store the spatial transforms
augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms]
self.aug_replay_spatial_transforms = augmented_params
if self.dataset_config.replay_transforms:
self.aug_replay_spatial_transforms = augmented_params
# convert back to RGB tensor
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)

View File

@@ -240,7 +240,7 @@ def get_direct_guidance_loss(
noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)
guidance_scale = 1.0
guidance_scale = 1.25
guidance_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
@@ -586,6 +586,8 @@ def get_guided_tnt(
loss = prior_loss + this_loss - that_loss
loss = loss.mean()
loss.backward()
# detach it so the parent class can run backward with no grads without throwing an error

View File

@@ -1,5 +1,5 @@
import torch
from transformers import Adafactor
from transformers import Adafactor, AdamW
def get_optimizer(
@@ -69,7 +69,7 @@ def get_optimizer(
if 'relative_step' not in optimizer_params:
optimizer_params['relative_step'] = False
if 'scale_parameter' not in optimizer_params:
optimizer_params['scale_parameter'] = True
optimizer_params['scale_parameter'] = False
if 'warmup_init' not in optimizer_params:
optimizer_params['warmup_init'] = False
optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)

View File

@@ -39,7 +39,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -497,6 +498,12 @@ class StableDiffusion:
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
elif isinstance(self.adapter, ControlNetModel):
if self.is_xl:
Pipe = StableDiffusionXLControlNetPipeline
else:
Pipe = StableDiffusionControlNetPipeline
extra_args['controlnet'] = self.adapter
elif isinstance(self.adapter, ReferenceAdapter):
# pass the noise scheduler to the adapter
self.adapter.noise_scheduler = noise_scheduler
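For context, a hedged sketch of how the ControlNet pipeline classes selected above are typically assembled for sampling; the model ids and conditioning-image URL are placeholders, not values from this repo:

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

control_image = load_image("https://example.com/canny_edges.png")  # placeholder URL
image = pipe(
    "a photo of a cat",
    image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
).images[0]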
@@ -588,6 +595,10 @@ class StableDiffusion:
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, ControlNetModel):
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['image'] = validation_image
extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),
@@ -967,6 +978,16 @@ class StableDiffusion:
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
# handle controlnet
if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs:
# go through each item and concat if doing cfg and it doesn't have the same shape
for idx, item in enumerate(kwargs['down_block_additional_residuals']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
for idx, item in enumerate(kwargs['mid_block_additional_residual']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0)
def scale_model_input(model_input, timestep_tensor):
if is_input_scaled:
return model_input
@@ -1383,11 +1404,13 @@ class StableDiffusion:
# move to device and dtype
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
# resize images if not divisible by the vae scale factor
for i in range(len(image_list)):
image = image_list[i]
if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0:
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
images = torch.stack(image_list)
if isinstance(self.vae, AutoencoderTiny):
@@ -1756,6 +1779,9 @@ class StableDiffusion:
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
elif isinstance(self.adapter, ControlNetModel):
requires_grad = self.adapter.conv_in.training
adapter_device = self.adapter.device
elif isinstance(self.adapter, ClipVisionAdapter):
requires_grad = self.adapter.embedder.training
adapter_device = self.adapter.device

View File

@@ -158,7 +158,7 @@ def get_style_model_and_losses(
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
content_layers = ['conv2_2', 'conv3_2', 'conv4_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
# set all weights in the model to our dtype

View File

@@ -81,7 +81,8 @@ def step_adafactor(self, closure=None):
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad ** 2) + group["eps"][0]
eps = group["eps"][0] if isinstance(group["eps"], list) else group["eps"]
update = (grad ** 2) + eps
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
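For reference on the eps guard above: transformers' Adafactor documents eps as a pair (1e-30, 1e-3), while the get_optimizer() helper earlier in this diff passes a single float, so group["eps"] can be either a sequence or a scalar at step time. A small illustrative sketch:

import torch
from transformers import Adafactor

params = [torch.nn.Parameter(torch.randn(4, 4))]

# pair form (library default) -> group["eps"] is a tuple
opt_a = Adafactor(params, lr=1e-4, eps=(1e-30, 1e-3),
                  relative_step=False, scale_parameter=False, warmup_init=False)
# scalar form (as passed by get_optimizer above) -> group["eps"] is a float
opt_b = Adafactor(params, lr=1e-4, eps=1e-6,
                  relative_step=False, scale_parameter=False, warmup_init=False)

for opt in (opt_a, opt_b):
    eps = opt.param_groups[0]["eps"]
    eps0 = eps[0] if isinstance(eps, (list, tuple)) else eps
    print(type(eps).__name__, eps0)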