Mirror of https://github.com/ostris/ai-toolkit.git
Commit: Bug fixes and little improvements here and there.
@@ -4,7 +4,7 @@ from collections import OrderedDict
from typing import Union, Literal, List, Optional

import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny
from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel

import torch.functional as F
from safetensors.torch import load_file

@@ -824,6 +824,10 @@ class SDTrainer(BaseSDTrainProcess):
# remove the residuals as we won't use them on prediction when matching control
if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
del pred_kwargs['down_intrablock_additional_residuals']
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
del pred_kwargs['mid_block_additional_residual']

if can_disable_adapter:
self.adapter.is_active = was_adapter_active

@@ -1065,7 +1069,7 @@ class SDTrainer(BaseSDTrainProcess):
# if prompt_2 is not None:
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]

with network:
with (network):
# encode clip adapter here so embeds are active for tokenizer
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('encode_clip_vision_embeds'):

@@ -1162,26 +1166,27 @@ class SDTrainer(BaseSDTrainProcess):

# flush()
pred_kwargs = {}
if has_adapter_img and (
(self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]

pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]

pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals

if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):

@@ -1362,6 +1367,32 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_cfg:
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)

if has_adapter_img:
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if self.train_config.do_cfg:
raise ValueError("ControlNetModel is not supported with CFG")
with torch.set_grad_enabled(self.adapter is not None):
adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
# add_text_embeds is pooled_prompt_embeds for sdxl
added_cond_kwargs = {}
if self.sd.is_xl:
added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
down_block_res_samples, mid_block_res_sample = adapter(
noisy_latents,
timesteps,
encoder_hidden_states=conditional_embeds.text_embeds,
controlnet_cond=adapter_images,
conditioning_scale=1.0,
guess_mode=False,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
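For context on the block above: the residuals returned by the ControlNet are carried through pred_kwargs straight into the UNet call. A minimal sketch of that handoff with the public diffusers API (model ids and tensor shapes here are illustrative, not taken from this repo):

import torch
from diffusers import ControlNetModel, UNet2DConditionModel

# illustrative SD 1.5-compatible pair; any matching base model and ControlNet work the same way
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

latents = torch.randn(1, 4, 64, 64)        # noisy latents
timestep = torch.tensor([999])
text_embeds = torch.randn(1, 77, 768)      # encoder hidden states
cond_image = torch.randn(1, 3, 512, 512)   # control image

down_res, mid_res = controlnet(
    latents, timestep,
    encoder_hidden_states=text_embeds,
    controlnet_cond=cond_image,
    conditioning_scale=1.0,
    return_dict=False,
)

# same keyword names the trainer stores in pred_kwargs above
noise_pred = unet(
    latents, timestep,
    encoder_hidden_states=text_embeds,
    down_block_additional_residuals=down_res,
    mid_block_additional_residual=mid_res,
).sample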
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the guidance later

@@ -1423,10 +1454,10 @@ class SDTrainer(BaseSDTrainProcess):
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
if self.is_bfloat:
loss.backward()
else:
self.scaler.scale(loss).backward()
# if self.is_bfloat:
loss.backward()
# else:
# self.scaler.scale(loss).backward()
# flush()

if not self.is_grad_accumulation_step:

@@ -1443,8 +1474,8 @@ class SDTrainer(BaseSDTrainProcess):
self.optimizer.step()
else:
# apply gradients
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.step()
# self.scaler.update()
# self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
else:
@@ -10,7 +10,7 @@ from typing import Union, List, Optional

import numpy as np
import yaml
from diffusers import T2IAdapter
from diffusers import T2IAdapter, ControlNetModel
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
from torch.utils.data import DataLoader

@@ -143,7 +143,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None
self.embedding: Union[Embedding, None] = None

is_training_adapter = self.adapter_config is not None and self.adapter_config.train

@@ -368,6 +368,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
pass

def save(self, step=None):
flush()
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)

@@ -423,6 +424,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# add _lora to name
if self.adapter_config.type == 't2i':
adapter_name += '_t2i'
elif self.adapter_config.type == 'control_net':
adapter_name += '_cn'
elif self.adapter_config.type == 'clip':
adapter_name += '_clip'
elif self.adapter_config.type.startswith('ip'):

@@ -441,6 +444,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
elif self.adapter_config.type == 'control_net':
# save in diffusers format
name_or_path = file_path.replace('.safetensors', '')
# move it to the new dtype and cpu
orig_device = self.adapter.device
orig_dtype = self.adapter.dtype
self.adapter = self.adapter.to(torch.device('cpu'), dtype=get_torch_dtype(self.save_config.dtype))
self.adapter.save_pretrained(
name_or_path,
dtype=get_torch_dtype(self.save_config.dtype),
safe_serialization=True
)
meta_path = os.path.join(name_or_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(self.meta, f)
# move it back
self.adapter = self.adapter.to(orig_device, dtype=orig_dtype)
else:
save_ip_adapter_from_diffusers(
state_dict,
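Since the control_net branch above saves with save_pretrained() into a diffusers folder and writes an aitk_meta.yaml next to the weights, resuming is just a from_pretrained() on that folder. A small sketch, with the folder path as a placeholder:

import os
import yaml
import torch
from diffusers import ControlNetModel

save_dir = "output/my_job_cn_000001000"  # placeholder folder written by save_pretrained above

adapter = ControlNetModel.from_pretrained(save_dir, torch_dtype=torch.float16)

# the training metadata stored beside the weights
with open(os.path.join(save_dir, "aitk_meta.yaml")) as f:
    meta = yaml.safe_load(f)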
@@ -551,6 +571,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
paths = [p for p in paths if '_refiner' not in p]
if '_t2i' not in name:
paths = [p for p in paths if '_t2i' not in p]
if '_cn' not in name:
paths = [p for p in paths if '_cn' not in p]

if len(paths) > 0:
latest_path = max(paths, key=os.path.getctime)

@@ -956,8 +978,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
def setup_adapter(self):
# t2i adapter
is_t2i = self.adapter_config.type == 't2i'
is_control_net = self.adapter_config.type == 'control_net'
if self.adapter_config.type == 't2i':
suffix = 't2i'
elif self.adapter_config.type == 'control_net':
suffix = 'cn'
elif self.adapter_config.type == 'clip':
suffix = 'clip'
elif self.adapter_config.type == 'reference':

@@ -990,6 +1015,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
elif is_control_net:
if self.adapter_config.name_or_path is None:
raise ValueError("ControlNet requires a name_or_path to load from currently")
load_from_path = self.adapter_config.name_or_path
if latest_save_path is not None:
load_from_path = latest_save_path
self.adapter = ControlNetModel.from_pretrained(
load_from_path,
torch_dtype=get_torch_dtype(self.train_config.dtype),
)
elif self.adapter_config.type == 'clip':
self.adapter = ClipVisionAdapter(
sd=self.sd,

@@ -1013,7 +1048,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
if latest_save_path is not None:
if latest_save_path is not None and not is_control_net:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
if is_t2i:

@@ -1040,8 +1075,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
if latest_save_path is not None and self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
self.sd.adapter = self.adapter
@@ -1,6 +1,7 @@
import copy
import glob
import os
import shutil
import time
from collections import OrderedDict

@@ -13,6 +14,7 @@ from torch import nn
from torchvision.transforms import transforms

from jobs.process import BaseTrainProcess
from toolkit.image_utils import show_tensors
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss

@@ -25,6 +27,8 @@ from tqdm import tqdm
import time
import numpy as np
from .models.vgg19_critic import Critic
from torchvision.transforms import Resize
import lpips

IMAGE_TRANSFORMS = transforms.Compose(
[

@@ -62,6 +66,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
self.optimizer_params = self.get_conf('optimizer_params', {})

@@ -71,6 +76,9 @@ class TrainVAEProcess(BaseTrainProcess):
self.vgg_19 = None
self.style_weight_scalers = []
self.content_weight_scalers = []
self.lpips_loss:lpips.LPIPS = None

self.vae_scale_factor = 8

self.step_num = 0
self.epoch_num = 0

@@ -137,6 +145,15 @@ class TrainVAEProcess(BaseTrainProcess):
num_workers=6
)

def remove_oldest_checkpoint(self):
max_to_keep = 4
folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
if len(folders) > max_to_keep:
folders.sort(key=os.path.getmtime)
for folder in folders[:-max_to_keep]:
print(f"Removing {folder}")
shutil.rmtree(folder)

def setup_vgg19(self):
if self.vgg_19 is None:
self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(

@@ -211,7 +228,7 @@ class TrainVAEProcess(BaseTrainProcess):

def get_pattern_loss(self, pred, target):
if self._pattern_loss is None:
self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device,
self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(self.device,
dtype=self.torch_dtype)
loss = torch.mean(self._pattern_loss(pred, target))
return loss

@@ -226,25 +243,21 @@ class TrainVAEProcess(BaseTrainProcess):
step_num = f"_{str(step).zfill(9)}"

self.update_training_metadata()
filename = f'{self.job.name}{step_num}.safetensors'
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
filename = f'{self.job.name}{step_num}_diffusers'

state_dict = convert_diffusers_back_to_ldm(self.vae)

for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(torch.float32)
state_dict[key] = v

# having issues with meta
save_file(state_dict, os.path.join(self.save_root, filename), save_meta)
self.vae = self.vae.to("cpu", dtype=torch.float16)
self.vae.save_pretrained(
save_directory=os.path.join(self.save_root, filename)
)
self.vae = self.vae.to(self.device, dtype=self.torch_dtype)

self.print(f"Saved to {os.path.join(self.save_root, filename)}")

if self.use_critic:
self.critic.save(step)

self.remove_oldest_checkpoint()

def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder):

@@ -280,6 +293,13 @@ class TrainVAEProcess(BaseTrainProcess):
output_img.paste(input_img, (0, 0))
output_img.paste(decoded, (self.resolution, 0))

scale_up = 2
if output_img.height <= 300:
scale_up = 4

# scale up using nearest neighbor
output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)

step_num = ''
if step is not None:
# zero-pad 9 digits

@@ -294,7 +314,7 @@ class TrainVAEProcess(BaseTrainProcess):
path_to_load = self.vae_path
# see if we have a checkpoint in our output to resume from
self.print(f"Looking for latest checkpoint in {self.save_root}")
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
if files and len(files) > 0:
latest_file = max(files, key=os.path.getmtime)
print(f" - Latest checkpoint is: {latest_file}")

@@ -306,13 +326,14 @@ class TrainVAEProcess(BaseTrainProcess):
self.print(f"Loading VAE")
self.print(f" - Loading VAE: {path_to_load}")
if self.vae is None:
self.vae = load_vae(path_to_load, dtype=self.torch_dtype)
self.vae = AutoencoderKL.from_pretrained(path_to_load)

# set decoder to train
self.vae.to(self.device, dtype=self.torch_dtype)
self.vae.requires_grad_(False)
self.vae.eval()
self.vae.decoder.train()
self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
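The scale factor is now derived from the VAE config instead of being hard-coded to 8. As a worked example, assuming the standard SD AutoencoderKL config:

# block_out_channels for the stock SD VAE has 4 entries, e.g. [128, 256, 512, 512]
vae_scale_factor = 2 ** (len([128, 256, 512, 512]) - 1)  # 2 ** 3 == 8
# so a 512x512 image encodes to a 64x64 latent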
def run(self):
super().run()

@@ -374,6 +395,10 @@ class TrainVAEProcess(BaseTrainProcess):
if self.use_critic:
self.critic.setup()

if self.lpips_weight > 0 and self.lpips_loss is None:
# self.lpips_loss = lpips.LPIPS(net='vgg')
self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)

optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
optimizer_params=self.optimizer_params)

@@ -397,6 +422,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.sample()
blank_losses = OrderedDict({
"total": [],
"lpips": [],
"style": [],
"content": [],
"mse": [],

@@ -415,17 +441,29 @@ class TrainVAEProcess(BaseTrainProcess):
for batch in self.data_loader:
if self.step_num >= self.max_steps:
break
with torch.no_grad():

batch = batch.to(self.device, dtype=self.torch_dtype)
batch = batch.to(self.device, dtype=self.torch_dtype)

# forward pass
dgd = self.vae.encode(batch).latent_dist
mu, logvar = dgd.mean, dgd.logvar
latents = dgd.sample()
latents.requires_grad_(True)
# resize so it matches size of vae evenly
if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)

# forward pass
dgd = self.vae.encode(batch).latent_dist
mu, logvar = dgd.mean, dgd.logvar
latents = dgd.sample()
latents.detach().requires_grad_(True)

pred = self.vae.decode(latents).sample

with torch.no_grad():
show_tensors(
pred.clamp(-1, 1).clone(),
"combined tensor"
)

# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, batch], dim=0)

@@ -441,14 +479,31 @@ class TrainVAEProcess(BaseTrainProcess):
content_loss = self.get_content_loss() * self.content_weight
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
if self.lpips_weight > 0:
lpips_loss = self.lpips_loss(
pred.clamp(-1, 1),
batch.clamp(-1, 1)
).mean() * self.lpips_weight
else:
lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
if self.use_critic:
critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight

# do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
if self.lpips_weight > 0:
max_target = lpips_loss.abs() * 0.1
with torch.no_grad():
crit_g_scaler = 1.0
if critic_gen_loss.abs() > max_target:
crit_g_scaler = max_target / critic_gen_loss.abs()

critic_gen_loss *= crit_g_scaler
else:
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss

# Backward pass and optimization
optimizer.zero_grad()
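On the new lpips term: the LPIPS module compares images in the [-1, 1] range, which is why both pred and batch are clamped before the call. A minimal standalone sketch of the same usage (shapes are illustrative):

import torch
import lpips

loss_fn = lpips.LPIPS(net='vgg')            # same backbone as above
pred = torch.rand(1, 3, 256, 256) * 2 - 1   # values in [-1, 1]
target = torch.rand(1, 3, 256, 256) * 2 - 1
perceptual = loss_fn(pred, target).mean()   # scalar loss term to add to the total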
@@ -460,6 +515,8 @@ class TrainVAEProcess(BaseTrainProcess):
loss_value = loss.item()
# get exponent like 3.54e-4
loss_string = f"loss: {loss_value:.2e}"
if self.lpips_weight > 0:
loss_string += f" lpips: {lpips_loss.item():.2e}"
if self.content_weight > 0:
loss_string += f" cnt: {content_loss.item():.2e}"
if self.style_weight > 0:

@@ -496,6 +553,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.progress_bar.update(1)

epoch_losses["total"].append(loss_value)
epoch_losses["lpips"].append(lpips_loss.item())
epoch_losses["style"].append(style_loss.item())
epoch_losses["content"].append(content_loss.item())
epoch_losses["mse"].append(mse_loss.item())

@@ -506,6 +564,7 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["crD"].append(critic_d_loss)

log_losses["total"].append(loss_value)
log_losses["lpips"].append(lpips_loss.item())
log_losses["style"].append(style_loss.item())
log_losses["content"].append(content_loss.item())
log_losses["mse"].append(mse_loss.item())
@@ -24,4 +24,5 @@ controlnet_aux==0.0.7
python-dotenv
bitsandbytes
xformers
hf_transfer
hf_transfer
lpips
@@ -7,10 +7,11 @@ from torchvision import transforms
import sys
import os
import cv2
import random

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.paths import SD_SCRIPTS_ROOT

import torchvision.transforms.functional
from toolkit.image_utils import show_img

sys.path.append(SD_SCRIPTS_ROOT)

@@ -32,7 +33,7 @@ parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()

dataset_folder = args.dataset_folder
resolution = 1024
resolution = 512
bucket_tolerance = 64
batch_size = 1

@@ -41,6 +42,7 @@ batch_size = 1

dataset_config = DatasetConfig(
dataset_path=dataset_folder,
control_path=dataset_folder,
resolution=resolution,
# caption_ext='json',
default_caption='default',

@@ -48,62 +50,135 @@ dataset_config = DatasetConfig(
buckets=True,
bucket_tolerance=bucket_tolerance,
# poi='person',
shuffle_augmentations=True,
# shuffle_augmentations=True,
# augmentations=[
# {
# 'method': 'GaussianBlur',
# 'blur_limit': (1, 16),
# 'sigma_limit': (0, 8),
# 'p': 0.8
# },
# {
# 'method': 'ImageCompression',
# 'quality_lower': 10,
# 'quality_upper': 100,
# 'compression_type': 0,
# 'p': 0.8
# },
# {
# 'method': 'ImageCompression',
# 'quality_lower': 20,
# 'quality_upper': 100,
# 'compression_type': 1,
# 'p': 0.8
# },
# {
# 'method': 'RingingOvershoot',
# 'blur_limit': (3, 35),
# 'cutoff': (0.7, 1.96),
# 'p': 0.8
# },
# {
# 'method': 'GaussNoise',
# 'var_limit': (0, 300),
# 'per_channel': True,
# 'mean': 0.0,
# 'p': 0.8
# },
# {
# 'method': 'GlassBlur',
# 'sigma': 0.6,
# 'max_delta': 7,
# 'iterations': 2,
# 'mode': 'fast',
# 'p': 0.8
# },
# {
# 'method': 'Downscale',
# 'scale_max': 0.5,
# 'interpolation': 'cv2.INTER_CUBIC',
# 'p': 0.8
# 'method': 'Posterize',
# 'num_bits': [(0, 4), (0, 4), (0, 4)],
# 'p': 1.0
# },
#
# ]


)

dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)

def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
if random.random() < p:
kernel_size = random.randint(min_kernel_size, max_kernel_size)
# make sure it is odd
if kernel_size % 2 == 0:
kernel_size += 1
img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size)
return img

def quantize(image, palette):
"""
Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
Only works for one image i.e. CHW. Does NOT work for batches.
ref https://discuss.pytorch.org/t/color-quantization/104528/4
"""

orig_dtype = image.dtype

C, H, W = image.shape
n_colors = palette.shape[0]

# Easier to work with list of colors
flat_img = image.view(C, -1).T # [C, H, W] -> [H*W, C]

# Repeat image so that there are n_color number of columns of the same image
flat_img_per_color = flat_img.unsqueeze(1).expand(-1, n_colors, -1) # [H*W, C] -> [H*W, n_colors, C]

# Get euclidean distance between each pixel in each column and the column's respective color
# i.e. column 1 lists distance of each pixel to color #1 in palette, column 2 to color #2 etc.
squared_distance = (flat_img_per_color - palette.unsqueeze(0)) ** 2
euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8) # [H*W, n_colors, C] -> [H*W, n_colors]

# Get the shortest distance (one value per row (H*W) is selected)
min_distances, min_indices = torch.min(euclidean_distance, dim=-1) # [H*W, n_colors] -> [H*W]

# Create a mask for the closest colors
one_hot_mask = torch.nn.functional.one_hot(min_indices, num_classes=n_colors).float() # [H*W, n_colors]

# Multiply the mask with the palette colors to get the quantized image
quantized = torch.matmul(one_hot_mask, palette) # [H*W, n_colors] @ [n_colors, C] -> [H*W, C]

# Reshape it back to the original input format.
quantized_img = quantized.T.view(C, H, W) # [H*W, C] -> [C, H, W]

return quantized_img.to(orig_dtype)
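A quick usage sketch for quantize() above (random data; the 0-255 palette values match the uint8-style range used by color_block_imgs below):

import torch

img = torch.rand(3, 64, 64) * 255                    # CHW float image, values 0-255
palette = torch.tensor([[0., 0., 0.],                # black
                        [255., 255., 255.],          # white
                        [255., 0., 0.]])             # red
quantized = quantize(img, palette)                   # every pixel snapped to its nearest palette color
assert quantized.shape == img.shape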
def color_block_imgs(img, neg1_1=False):
# expects values 0 - 1
orig_dtype = img.dtype
if neg1_1:
img = img * 0.5 + 0.5

img = img * 255
img = img.clamp(0, 255)
img = img.to(torch.uint8)

img_chunks = torch.chunk(img, img.shape[0], dim=0)

posterized_chunks = []

for chunk in img_chunks:
img_size = (chunk.shape[2] + chunk.shape[3]) // 2
# min kernel size of 1% of image, max 10%
min_kernel_size = int(img_size * 0.01)
max_kernel_size = int(img_size * 0.1)

# blur first
chunk = random_blur(chunk, min_kernel_size=min_kernel_size, max_kernel_size=max_kernel_size, p=0.8)
num_colors = random.randint(1, 16)

resize_to = 16
# chunk = torchvision.transforms.functional.posterize(chunk, num_bits_to_use)

# mean_color = [int(x.item()) for x in torch.mean(chunk.float(), dim=(0, 2, 3))]

# shrink the image down to num_colors x num_colors
shrunk = torchvision.transforms.functional.resize(chunk, [resize_to, resize_to])

mean_color = [int(x.item()) for x in torch.mean(shrunk.float(), dim=(0, 2, 3))]

colors = shrunk.view(3, -1).T
# remove duplicates
colors = torch.unique(colors, dim=0)
colors = colors.numpy()
colors = colors.tolist()

use_colors = [random.choice(colors) for _ in range(num_colors)]

pallette = torch.tensor([
[0, 0, 0],
mean_color,
[255, 255, 255],
] + use_colors, dtype=torch.float32)
chunk = quantize(chunk.squeeze(0), pallette).unsqueeze(0)

# chunk = torchvision.transforms.functional.equalize(chunk)
# color jitter
if random.random() < 0.5:
chunk = torchvision.transforms.functional.adjust_contrast(chunk, random.uniform(1.0, 1.5))
if random.random() < 0.5:
chunk = torchvision.transforms.functional.adjust_saturation(chunk, random.uniform(1.0, 2.0))
# if random.random() < 0.5:
# chunk = torchvision.transforms.functional.adjust_brightness(chunk, random.uniform(0.5, 1.5))
chunk = random_blur(chunk, p=0.6)
posterized_chunks.append(chunk)

img = torch.cat(posterized_chunks, dim=0)
img = img.to(orig_dtype)
img = img / 255

if neg1_1:
img = img * 2 - 1
return img

# run through an epoch and check sizes
dataloader_iterator = iter(dataloader)

@@ -112,11 +187,19 @@ for epoch in range(args.epochs):
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor

img_batch = color_block_imgs(img_batch, neg1_1=True)

chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them so they are side by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)

control_chunks = torch.chunk(batch.control_tensor, batch_size, dim=0)
big_control_img = torch.cat(control_chunks, dim=3)
big_control_img = big_control_img.squeeze(0) * 2 - 1

big_img = torch.cat([big_img, big_control_img], dim=2)

min_val = big_img.min()
max_val = big_img.max()

@@ -127,7 +210,7 @@ for epoch in range(args.epochs):

show_img(img)

# time.sleep(1.0)
time.sleep(1.0)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)
@@ -129,13 +129,13 @@ class NetworkConfig:
self.conv = 4


AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']

CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']

class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i')  # t2i, ip, clip
self.type: AdapterTypes = kwargs.get('type', 't2i')  # t2i, ip, clip, control_net
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
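With 'control_net' added to AdapterTypes, an adapter config can now select it. A hedged sketch of the kwargs AdapterConfig would receive (the full training-config layout is not shown in this diff; only type, name_or_path and train are referenced elsewhere in it, and the model id here is illustrative):

adapter_config = AdapterConfig(
    type='control_net',
    # control_net currently requires a diffusers ControlNet to start from
    name_or_path='lllyasviel/sd-controlnet-canny',
    train=True,
)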
@@ -530,6 +530,8 @@ class DatasetConfig:
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
self.extra_values: List[float] = kwargs.get('extra_values', [])
self.square_crop: bool = kwargs.get('square_crop', False)
# apply same augmentations to control images. Usually want this true unless special case
self.replay_transforms: bool = kwargs.get('replay_transforms', True)


def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

@@ -860,7 +860,8 @@ class AugmentationFileItemDTOMixin:
# only store the spatial transforms
augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms]

self.aug_replay_spatial_transforms = augmented_params
if self.dataset_config.replay_transforms:
self.aug_replay_spatial_transforms = augmented_params

# convert back to RGB tensor
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
@@ -240,7 +240,7 @@ def get_direct_guidance_loss(

noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)

guidance_scale = 1.0
guidance_scale = 1.25
guidance_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)

@@ -586,6 +586,8 @@ def get_guided_tnt(

loss = prior_loss + this_loss - that_loss

loss = loss.mean()

loss.backward()

# detach it so parent class can run backward on no grads without throwing error
@@ -1,5 +1,5 @@
import torch
from transformers import Adafactor
from transformers import Adafactor, AdamW


def get_optimizer(

@@ -69,7 +69,7 @@ def get_optimizer(
if 'relative_step' not in optimizer_params:
optimizer_params['relative_step'] = False
if 'scale_parameter' not in optimizer_params:
optimizer_params['scale_parameter'] = True
optimizer_params['scale_parameter'] = False
if 'warmup_init' not in optimizer_params:
optimizer_params['warmup_init'] = False
optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
@@ -39,7 +39,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
import diffusers
from diffusers import \
AutoencoderKL, \

@@ -497,6 +498,12 @@ class StableDiffusion:
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
elif isinstance(self.adapter, ControlNetModel):
if self.is_xl:
Pipe = StableDiffusionXLControlNetPipeline
else:
Pipe = StableDiffusionControlNetPipeline
extra_args['controlnet'] = self.adapter
elif isinstance(self.adapter, ReferenceAdapter):
# pass the noise scheduler to the adapter
self.adapter.noise_scheduler = noise_scheduler
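The sampling path above just swaps the pipeline class and hands the trained ControlNet in as the controlnet argument. Outside the trainer, the equivalent diffusers call looks roughly like this (model ids are illustrative):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
# the image argument is the control image; its conditioning scale corresponds to
# gen_config.adapter_conditioning_scale used in the sampling hunk below
# result = pipe("a photo", image=control_image, controlnet_conditioning_scale=1.0).images[0]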
@@ -588,6 +595,10 @@ class StableDiffusion:
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, ControlNetModel):
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['image'] = validation_image
extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),

@@ -967,6 +978,16 @@ class StableDiffusion:
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)

# handle controlnet
if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs:
# go through each item and concat if doing cfg and it doesn't have the same shape
for idx, item in enumerate(kwargs['down_block_additional_residuals']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
for idx, item in enumerate(kwargs['mid_block_additional_residual']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0)

def scale_model_input(model_input, timestep_tensor):
if is_input_scaled:
return model_input
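The duplication above exists because, under classifier-free guidance, the UNet runs on a stacked [unconditional; conditional] batch, so a ControlNet residual computed for a single conditioning batch has to be repeated along dim 0 to match. A tiny illustration (shapes are made up):

import torch

residual = torch.randn(1, 320, 64, 64)           # computed for the conditional batch only
residual_cfg = torch.cat([residual] * 2, dim=0)  # (2, 320, 64, 64), matches the [uncond; cond] batch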
@@ -1383,11 +1404,13 @@ class StableDiffusion:
# move to device and dtype
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]

VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)

# resize images if not divisible by 8
for i in range(len(image_list)):
image = image_list[i]
if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0:
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)

images = torch.stack(image_list)
if isinstance(self.vae, AutoencoderTiny):

@@ -1756,6 +1779,9 @@ class StableDiffusion:
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
elif isinstance(self.adapter, ControlNetModel):
requires_grad = self.adapter.conv_in.training
adapter_device = self.adapter.device
elif isinstance(self.adapter, ClipVisionAdapter):
requires_grad = self.adapter.embedder.training
adapter_device = self.adapter.device
@@ -158,7 +158,7 @@ def get_style_model_and_losses(
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
content_layers = ['conv2_2', 'conv3_2', 'conv4_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
# set all weights in the model to our dtype
@@ -81,7 +81,8 @@ def step_adafactor(self, closure=None):
lr = self._get_lr(group, state)

beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad ** 2) + group["eps"][0]
eps = group["eps"][0] if isinstance(group["eps"], list) else group["eps"]
update = (grad ** 2) + eps
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
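Context for this hunk: the get_optimizer hunk above now constructs Adafactor with a scalar eps=1e-6, while a two-element eps is what the stock defaults use, so indexing group["eps"][0] unconditionally would break in the scalar case. A tiny demo of what the guard evaluates to for both forms (values illustrative):

for eps_setting in (1e-6, [1e-30, 1e-3]):
    eps = eps_setting[0] if isinstance(eps_setting, list) else eps_setting
    print(eps)  # prints 1e-06, then 1e-30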