# mirror of https://github.com/ostris/ai-toolkit.git
# synced 2026-02-10 15:39:57 +00:00

import copy
import os
import time
from collections import OrderedDict

from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torch import nn
from torchvision import transforms

from jobs.process import BaseTrainProcess
from toolkit.kohya_model_util import load_vae
from toolkit.data_loader import ImageDataset
from toolkit.metadata import get_meta_for_safetensors
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from diffusers import AutoencoderKL
from tqdm import tqdm
import numpy as np


IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def unnormalize(tensor):
    return (tensor / 2 + 0.5).clamp(0, 1)
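

# Minimal sketch (not called by the trainer) of the tensor round trip used in
# this file: ToTensor maps pixels to [0, 1], Normalize([0.5], [0.5]) shifts
# them to [-1, 1] (the range the VAE expects), and unnormalize() inverts that
# for display or saving.
def _roundtrip_example(pil_image: Image.Image) -> Image.Image:
    tensor = IMAGE_TRANSFORMS(pil_image)  # (C, H, W), values in [-1, 1]
    restored = unnormalize(tensor)        # back to [0, 1]
    return Image.fromarray((restored.permute(1, 2, 0).numpy() * 255).astype(np.uint8))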


class TrainVAEProcess(BaseTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.data_loader = None
        self.vae = None
        self.device = self.get_conf('device', self.job.device)
        self.vae_path = self.get_conf('vae_path', required=True)
        self.datasets_objects = self.get_conf('datasets', required=True)
        self.training_folder = self.get_conf('training_folder', self.job.training_folder)
        self.batch_size = self.get_conf('batch_size', 1)
        self.resolution = self.get_conf('resolution', 256)
        self.learning_rate = self.get_conf('learning_rate', 1e-4)
        self.sample_every = self.get_conf('sample_every', None)
        self.epochs = self.get_conf('epochs', None)
        self.max_steps = self.get_conf('max_steps', None)
        self.save_every = self.get_conf('save_every', None)
        self.dtype = self.get_conf('dtype', 'float32')
        self.sample_sources = self.get_conf('sample_sources', None)
        self.style_weight = self.get_conf('style_weight', 1e4)
        self.content_weight = self.get_conf('content_weight', 1)
        self.elbo_weight = self.get_conf('elbo_weight', 1e-8)
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.save_root = os.path.join(self.training_folder, self.job.name)
        self.vgg_19 = None
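
        # A hypothetical config entry covering the keys read above (shape
        # assumed from the get_conf calls; the repo's example configs are the
        # authoritative reference):
        #
        #   vae_path: "/path/to/vae.safetensors"
        #   datasets:
        #     - path: "/path/to/images"
        #   resolution: 256
        #   batch_size: 1
        #   learning_rate: 1.0e-4
        #   epochs: 10           # and/or max_steps
        #   sample_every: 500    # requires sample_sources
        #   sample_sources:
        #     - "/path/to/test_image.png"
        #   save_every: 1000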

        if self.sample_every is not None and self.sample_sources is None:
            raise ValueError('sample_every is specified but sample_sources is not')

        if self.epochs is None and self.max_steps is None:
            raise ValueError('epochs or max_steps must be specified')

        self.data_loaders = []
        # check datasets
        assert isinstance(self.datasets_objects, list)
        for dataset in self.datasets_objects:
            if 'path' not in dataset:
                raise ValueError('dataset must have a path')
            # check that the path is a directory
            if not os.path.isdir(dataset['path']):
                raise ValueError(f"dataset path is not a directory: {dataset['path']}")

        # make training folder
        os.makedirs(self.save_root, exist_ok=True)

    def load_datasets(self):
        if self.data_loader is None:
            print("Loading datasets")
            datasets = []
            for dataset in self.datasets_objects:
                print(f" - Dataset: {dataset['path']}")
                ds = copy.copy(dataset)
                ds['resolution'] = self.resolution
                image_dataset = ImageDataset(ds)
                datasets.append(image_dataset)

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=6
            )

    def setup_vgg19(self):
        if self.vgg_19 is None:
            self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses(
                single_target=True, device=self.device)
            self.vgg_19.requires_grad_(False)
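
    # Note on the style model: in the classic PyTorch neural-style recipe
    # (which get_style_model_and_losses appears to follow), the returned
    # style/content loss modules sit inside the VGG-19 pipeline, and a forward
    # pass populates each module's .loss attribute as a side effect. With
    # single_target=True each module presumably splits its input batch into
    # [prediction, target] halves, which is why run() feeds pred and batch
    # concatenated along dim 0.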

    def get_mse_loss(self, pred, target):
        loss_fn = nn.MSELoss()
        loss = loss_fn(pred, target)
        return loss

    def get_elbo_loss(self, pred, target, mu, log_var):
        # negative ELBO (Evidence Lower BOund), aka variational lower bound,
        # used as a loss: a summed MSE reconstruction term plus the KL
        # divergence of the posterior N(mu, sigma^2) from the unit-Gaussian prior
        reconstruction_loss = nn.MSELoss(reduction='sum')
        recon = reconstruction_loss(pred, target)  # reconstruction term
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())  # KL divergence
        return recon + KLD
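
    # For reference, the KLD line above is the closed form of
    # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions:
    #   KL = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # with log_var = log(sigma^2) substituted directly.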

    def save(self, step=None):
        os.makedirs(self.save_root, exist_ok=True)

        step_num = ''
        if step is not None:
            # zero-pad to 9 digits
            step_num = f"_{str(step).zfill(9)}"

        filename = f'{self.job.name}{step_num}.safetensors'
        save_path = os.path.join(self.save_root, filename)
        # prepare meta
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)

        # detach and cast everything to float32 on the CPU for saving
        state_dict = self.vae.state_dict()
        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(torch.float32)
            state_dict[key] = v

        # having issues with meta
        save_file(state_dict, save_path, save_meta)

        print(f"Saved to {save_path}")
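
    # Minimal sketch (assumed usage, not part of this file's API) of reloading
    # a checkpoint written by save(): the keys match the AutoencoderKL that was
    # trained, so rebuild the same architecture before loading.
    #
    #   from safetensors.torch import load_file
    #   vae = load_vae(vae_path, dtype=torch.float32)  # same source VAE
    #   vae.load_state_dict(load_file(checkpoint_path))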

    def sample(self, step=None):
        sample_folder = os.path.join(self.save_root, 'samples')
        os.makedirs(sample_folder, exist_ok=True)

        with torch.no_grad():
            self.vae.encoder.eval()
            self.vae.decoder.eval()

            for i, img_url in enumerate(self.sample_sources):
                img = exif_transpose(Image.open(img_url))
                img = img.convert('RGB')
                # crop to a square (top-left corner) if not already square
                if img.width != img.height:
                    min_dim = min(img.width, img.height)
                    img = img.crop((0, 0, min_dim, min_dim))
                # resize
                img = img.resize((self.resolution, self.resolution))

                input_img = img
                img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
                decoded = self.vae(img).sample
                decoded = unnormalize(decoded)
                # always cast to float32; the overhead is negligible and it is
                # compatible with bfloat16
                decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()

                # convert to a pillow image
                decoded = Image.fromarray((decoded * 255).astype(np.uint8))

                # stack the input image and the decoded image side by side
                input_img = input_img.resize((self.resolution, self.resolution))
                decoded = decoded.resize((self.resolution, self.resolution))

                output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
                output_img.paste(input_img, (0, 0))
                output_img.paste(decoded, (self.resolution, 0))

                step_num = ''
                if step is not None:
                    # zero-pad to 9 digits
                    step_num = f"_{str(step).zfill(9)}"
                seconds_since_epoch = int(time.time())
                # zero-pad to 2 digits
                i_str = str(i).zfill(2)
                filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
                output_img.save(os.path.join(sample_folder, filename))

    def run(self):
        super().run()
        self.load_datasets()

        # resolve the epoch/step budget; __init__ guarantees at least one of
        # epochs / max_steps is set, so guard against the other being None
        steps_per_epoch = len(self.data_loader)
        num_epochs = self.epochs
        if self.max_steps is not None:
            # round up so a max_steps budget smaller than one epoch still runs
            max_step_epochs = -(-self.max_steps // steps_per_epoch)
            if num_epochs is None or num_epochs > max_step_epochs:
                num_epochs = max_step_epochs

        max_epoch_steps = steps_per_epoch * num_epochs
        num_steps = self.max_steps
        if num_steps is None or num_steps > max_epoch_steps:
            num_steps = max_epoch_steps

        print("Training VAE")
        print(f" - Training folder: {self.training_folder}")
        print(f" - Batch size: {self.batch_size}")
        print(f" - Learning rate: {self.learning_rate}")
        print(f" - Epochs: {num_epochs}")
        print(f" - Max steps: {num_steps}")

        # load vae
        print("Loading VAE")
        print(f" - Loading VAE: {self.vae_path}")
        if self.vae is None:
            self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)

        # freeze the whole model, then set only the decoder to train mode
        self.vae.to(self.device, dtype=self.torch_dtype)
        self.vae.requires_grad_(False)
        self.vae.eval()
        self.vae.decoder.train()

        # train only the decoder's mid block and upsampling blocks
        blocks_to_train = [
            'mid_block',
            'up_blocks',
        ]

        params = []

        # freeze every decoder parameter, then re-enable the chosen blocks
        for param in self.vae.decoder.parameters():
            param.requires_grad = False

        if 'mid_block' in blocks_to_train:
            params += list(self.vae.decoder.mid_block.parameters())
            self.vae.decoder.mid_block.requires_grad_(True)
        if 'up_blocks' in blocks_to_train:
            params += list(self.vae.decoder.up_blocks.parameters())
            self.vae.decoder.up_blocks.requires_grad_(True)

        self.setup_vgg19()
        self.vgg_19.requires_grad_(False)
        self.vgg_19.eval()

        optimizer = torch.optim.Adam(params, lr=self.learning_rate)

        # setup scheduler
        # todo allow other schedulers
        # ConstantLR with factor=1 keeps the learning rate fixed at self.learning_rate
        scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer,
            total_iters=num_steps,
            factor=1,
            verbose=False
        )

        # setup tqdm progress bar
        progress_bar = tqdm(
            total=num_steps,
            desc='Training VAE',
            leave=True
        )

        step = 0
        # sample once before any training as a baseline
        self.sample()
        for epoch in range(num_epochs):
            if step >= num_steps:
                break
            for batch in self.data_loader:
                if step >= num_steps:
                    break

                batch = batch.to(self.device, dtype=self.torch_dtype)

                # forward pass through the (frozen) encoder
                dgd = self.vae.encode(batch).latent_dist
                mu, logvar = dgd.mean, dgd.logvar
                latents = dgd.sample()
                latents.requires_grad_(True)
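
                # the encoder is fully frozen, so the sampled latents carry no
                # autograd graph; marking them requires_grad makes them a leaf
                # through which gradients can flow back into the decoder's
                # trainable blocks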

                pred = self.vae.decode(latents).sample

                # run prediction and target through VGG-19 in a single batch;
                # the style/content loss modules record their losses as a side
                # effect of the forward pass
                stacked = torch.cat([pred, batch], dim=0)
                stacked = unnormalize(stacked)
                self.vgg_19(stacked)

                # sum the losses collected by the style and content modules
                style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
                content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
                elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar)

                style_loss = style_loss * self.style_weight
                content_loss = content_loss * self.content_weight
                elbo_loss = elbo_loss * self.elbo_weight

                loss = style_loss + content_loss + elbo_loss
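
                # the default weights (style 1e4, content 1, elbo 1e-8) appear
                # chosen to bring the three terms to a comparable scale: the
                # per-layer style losses are tiny, while the summed-MSE ELBO
                # term is huge relative to the perceptual losses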

                # backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # update progress bar
                loss_value = loss.item()
                # format values in exponent notation, e.g. 3.54e-04
                loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}"
                learning_rate = optimizer.param_groups[0]['lr']
                progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
                progress_bar.set_description(f"E: {epoch} - S: {step}")
                progress_bar.update(1)

                if step != 0:
                    if self.sample_every and step % self.sample_every == 0:
                        # print above the progress bar without breaking it
                        tqdm.write(f"Sampling at step {step}")
                        self.sample(step)

                    if self.save_every and step % self.save_every == 0:
                        # print above the progress bar without breaking it
                        tqdm.write(f"Saving at step {step}")
                        self.save(step)

                step += 1

        # final save
        self.save()