diff --git a/.gitmodules b/.gitmodules
index 9828d447..18079356 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "repositories/sd-scripts"]
 	path = repositories/sd-scripts
 	url = https://github.com/kohya-ss/sd-scripts.git
+[submodule "repositories/leco"]
+	path = repositories/leco
+	url = https://github.com/p1atdev/LECO
diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py
index 4137200b..8a0ad3bf 100644
--- a/jobs/TrainJob.py
+++ b/jobs/TrainJob.py
@@ -1,85 +1,39 @@
-# from jobs import BaseJob
-# from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
-# from collections import OrderedDict
-# from typing import List
-# from jobs.process import BaseExtractProcess, TrainFineTuneProcess
-# import gc
-# import time
-# import argparse
-# import itertools
-# import math
-# import os
-# from multiprocessing import Value
-#
-# from tqdm import tqdm
-# import torch
-# from accelerate.utils import set_seed
-# from accelerate import Accelerator
-# import diffusers
-# from diffusers import DDPMScheduler
-#
-# from toolkit.paths import SD_SCRIPTS_ROOT
-#
-# import sys
-#
-# sys.path.append(SD_SCRIPTS_ROOT)
-#
-# import library.train_util as train_util
-# import library.config_util as config_util
-# from library.config_util import (
-#     ConfigSanitizer,
-#     BlueprintGenerator,
-# )
-# import toolkit.train_tools as train_tools
-# import library.custom_train_functions as custom_train_functions
-# from library.custom_train_functions import (
-#     apply_snr_weight,
-#     get_weighted_text_embeddings,
-#     prepare_scheduler_for_custom_training,
-#     pyramid_noise_like,
-#     apply_noise_offset,
-#     scale_v_prediction_loss_like_noise_prediction,
-# )
-#
-# process_dict = {
-#     'fine_tine': 'TrainFineTuneProcess'
-# }
-#
-#
-# class TrainJob(BaseJob):
-#     process: List[BaseExtractProcess]
-#
-#     def __init__(self, config: OrderedDict):
-#         super().__init__(config)
-#         self.base_model_path = self.get_conf('base_model', required=True)
-#         self.base_model = None
-#         self.training_folder = self.get_conf('training_folder', required=True)
-#         self.is_v2 = self.get_conf('is_v2', False)
-#         self.device = self.get_conf('device', 'cpu')
-#         self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
-#         self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16
-#         self.logging_dir = self.get_conf('logging_dir', None)
-#
-#         # loads the processes from the config
-#         self.load_processes(process_dict)
-#
-#         # setup accelerator
-#         self.accelerator = Accelerator(
-#             gradient_accumulation_steps=self.gradient_accumulation_steps,
-#             mixed_precision=self.mixed_precision,
-#             log_with=None if self.logging_dir is None else 'tensorboard',
-#             logging_dir=self.logging_dir,
-#         )
-#
-#     def run(self):
-#         super().run()
-#         # load models
-#         print(f"Loading base model for training")
-#         print(f"  - Loading base model: {self.base_model_path}")
-#         self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
-#
-#         print("")
-#         print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
-#
-#         for process in self.process:
-#             process.run()
+from jobs import BaseJob
+from collections import OrderedDict
+from typing import List
+from jobs.process import BaseTrainProcess, TrainFineTuneProcess
+
+from toolkit.paths import REPOS_ROOT
+
+import sys
+
+sys.path.append(REPOS_ROOT)
+
+process_dict = {
+    'vae': 'TrainVAEProcess',
+    'finetune': 'TrainFineTuneProcess'
+}
+
+
+class TrainJob(BaseJob):
+    process: List[BaseTrainProcess]
+
+    def __init__(self, config: OrderedDict):
+        super().__init__(config)
+        self.training_folder = self.get_conf('training_folder', required=True)
+        self.is_v2 = self.get_conf('is_v2', False)
+        self.device = self.get_conf('device', 'cpu')
+        self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
+        self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16
+        self.logging_dir = self.get_conf('logging_dir', None)
+
+        # loads the processes from the config
+        self.load_processes(process_dict)
+
+    def run(self):
+        super().run()
+        print("")
+        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+        for process in self.process:
+            process.run()
diff --git a/jobs/__init__.py b/jobs/__init__.py
index 09be1770..9c2472f6 100644
--- a/jobs/__init__.py
+++ b/jobs/__init__.py
@@ -1,2 +1,3 @@
 from .BaseJob import BaseJob
 from .ExtractJob import ExtractJob
+from .TrainJob import TrainJob
diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py
index cac0335e..afcf9399 100644
--- a/jobs/process/BaseTrainProcess.py
+++ b/jobs/process/BaseTrainProcess.py
@@ -13,9 +13,6 @@ class BaseTrainProcess(BaseProcess):
             config: OrderedDict
     ):
         super().__init__(process_id, job, config)
-        self.process_id = process_id
-        self.job = job
-        self.config = config
 
     def run(self):
         # implement in child class
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
new file mode 100644
index 00000000..e30e37a3
--- /dev/null
+++ b/jobs/process/TrainVAEProcess.py
@@ -0,0 +1,284 @@
+import copy
+import os
+import time
+from collections import OrderedDict
+
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader, ConcatDataset
+import torch
+from torch import nn
+from torchvision.transforms import transforms
+
+from jobs.process import BaseTrainProcess
+from toolkit.kohya_model_util import load_vae
+from toolkit.data_loader import ImageDataset
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.train_tools import get_torch_dtype
+from tqdm import tqdm
+
+IMAGE_TRANSFORMS = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+INVERSE_IMAGE_TRANSFORMS = transforms.Compose(
+    [
+        transforms.Normalize(
+            mean=[-0.5/0.5],
+            std=[1/0.5]
+        ),
+        transforms.ToPILImage(),
+    ]
+)
+
+
+class TrainVAEProcess(BaseTrainProcess):
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.data_loader = None
+        self.vae = None
+        self.device = self.get_conf('device', self.job.device)
+        self.vae_path = self.get_conf('vae_path', required=True)
+        self.datasets_objects = self.get_conf('datasets', required=True)
+        self.training_folder = self.get_conf('training_folder', self.job.training_folder)
+        self.batch_size = self.get_conf('batch_size', 1)
+        self.resolution = self.get_conf('resolution', 256)
+        self.learning_rate = self.get_conf('learning_rate', 1e-4)
+        self.sample_every = self.get_conf('sample_every', None)
+        self.epochs = self.get_conf('epochs', None)
+        self.max_steps = self.get_conf('max_steps', None)
+        self.save_every = self.get_conf('save_every', None)
+        self.dtype = self.get_conf('dtype', 'float32')
+        self.sample_sources = self.get_conf('sample_sources', None)
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.save_root = os.path.join(self.training_folder, self.job.name)
+
+        if self.sample_every is not None and self.sample_sources is None:
+            raise ValueError('sample_every is specified but sample_sources is not')
+
+        if self.epochs is None and self.max_steps is None:
+            raise ValueError('epochs or max_steps must be specified')
+
+        # check datasets
+        assert isinstance(self.datasets_objects, list)
+        for dataset in self.datasets_objects:
+            if 'path' not in dataset:
+                raise ValueError('dataset must have a path')
+            # check if is dir
+            if not os.path.isdir(dataset['path']):
+                raise ValueError(f"dataset path is not a directory: {dataset['path']}")
+
+        # make training folder
+        if not os.path.exists(self.save_root):
+            os.makedirs(self.save_root, exist_ok=True)
+
+    def load_datasets(self):
+        if self.data_loader is None:
+            print("Loading datasets")
+            datasets = []
+            for dataset in self.datasets_objects:
+                print(f" - Dataset: {dataset['path']}")
+                ds = copy.copy(dataset)
+                ds['resolution'] = self.resolution
+                image_dataset = ImageDataset(ds)
+                datasets.append(image_dataset)
+
+            concatenated_dataset = ConcatDataset(datasets)
+            self.data_loader = DataLoader(
+                concatenated_dataset,
+                batch_size=self.batch_size,
+                shuffle=True
+            )
+
+    def get_loss(self, pred, target):
+        loss_fn = nn.MSELoss()
+        loss = loss_fn(pred, target)
+        return loss
+
+    def get_elbo_loss(self, pred, target, mu, log_var):
+        # ELBO (Evidence Lower BOund) loss, aka variational lower bound
+        reconstruction_loss = nn.MSELoss(reduction='sum')
+        recon = reconstruction_loss(pred, target)  # reconstruction term
+        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())  # KL divergence term
+        return recon + KLD
+
+    def save(self, step=None):
+        if not os.path.exists(self.save_root):
+            os.makedirs(self.save_root, exist_ok=True)
+
+        step_num = ''
+        if step is not None:
+            # zeropad 9 digits
+            step_num = f"_{str(step).zfill(9)}"
+
+        filename = f'{self.job.name}{step_num}.safetensors'
+        save_path = os.path.join(self.save_root, filename)
+        # prepare meta
+        save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+
+        state_dict = self.vae.state_dict()
+
+        for key in list(state_dict.keys()):
+            v = state_dict[key]
+            v = v.detach().clone().to("cpu").to(torch.float32)
+            state_dict[key] = v
+
+        save_file(state_dict, save_path, save_meta)
+
+        print(f"Saved to {save_path}")
+
+    def sample(self, step=None):
+        sample_folder = os.path.join(self.save_root, 'samples')
+        if not os.path.exists(sample_folder):
+            os.makedirs(sample_folder, exist_ok=True)
+
+        with torch.no_grad():
+            self.vae.encoder.eval()
+            self.vae.decoder.eval()
+
+            for i, img_url in enumerate(self.sample_sources):
+                img = exif_transpose(Image.open(img_url))
+                img = img.convert('RGB')
+                # crop if not square
+                if img.width != img.height:
+                    min_dim = min(img.width, img.height)
+                    img = img.crop((0, 0, min_dim, min_dim))
+                # resize
+                img = img.resize((self.resolution, self.resolution))
+
+                input_img = img
+                img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
+                decoded = self.vae(img).sample.squeeze(0)
+                decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
+
+                # stack input image and decoded image
+                input_img = input_img.resize((self.resolution, self.resolution))
+                decoded = decoded.resize((self.resolution, self.resolution))
+
+                output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
+                output_img.paste(input_img, (0, 0))
+                output_img.paste(decoded, (self.resolution, 0))
+
+                step_num = ''
+                if step is not None:
+                    # zeropad 9 digits
+                    step_num = f"_{str(step).zfill(9)}"
+                seconds_since_epoch = int(time.time())
+                # zeropad 2 digits
+                i_str = str(i).zfill(2)
+                filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
+                output_img.save(os.path.join(sample_folder, filename))
+        self.vae.decoder.train()
+
+    def run(self):
+        super().run()
+        self.load_datasets()
+
+        # either epochs or max_steps may be None (but not both); cap each by the other
+        num_epochs = self.epochs
+        if self.max_steps is not None:
+            max_step_epochs = self.max_steps // len(self.data_loader)
+            if num_epochs is None or num_epochs > max_step_epochs:
+                num_epochs = max_step_epochs
+
+        max_epoch_steps = len(self.data_loader) * num_epochs
+        num_steps = self.max_steps
+        if num_steps is None or num_steps > max_epoch_steps:
+            num_steps = max_epoch_steps
+
+        print("Training VAE")
+        print(f" - Training folder: {self.training_folder}")
+        print(f" - Batch size: {self.batch_size}")
+        print(f" - Learning rate: {self.learning_rate}")
+        print(f" - Epochs: {num_epochs}")
+        print(f" - Max steps: {self.max_steps}")
+
+        # load vae
+        print("Loading VAE")
+        print(f" - Loading VAE: {self.vae_path}")
+        if self.vae is None:
+            self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
+
+        # freeze the whole VAE, then unfreeze only the decoder for training
+        self.vae.to(self.device, dtype=self.torch_dtype)
+        self.vae.requires_grad_(False)
+        self.vae.eval()
+
+        self.vae.decoder.requires_grad_(True)
+        self.vae.decoder.train()
+
+        parameters = self.vae.decoder.parameters()
+
+        optimizer = torch.optim.Adam(parameters, lr=self.learning_rate)
+
+        # setup scheduler
+        # todo allow other schedulers
+        scheduler = torch.optim.lr_scheduler.ConstantLR(
+            optimizer,
+            total_iters=num_steps,
+            factor=1,
+            verbose=False
+        )
+
+        # setup tqdm progress bar
+        progress_bar = tqdm(
+            total=num_steps,
+            desc='Training VAE',
+            leave=True
+        )
+
+        step = 0
+        # sample first
+        self.sample()
+        for epoch in range(num_epochs):
+            if step >= num_steps:
+                break
+            for batch in self.data_loader:
+                if step >= num_steps:
+                    break
+
+                batch = batch.to(self.device, dtype=self.torch_dtype)
+
+                # forward pass
+                dgd = self.vae.encode(batch).latent_dist
+                mu, logvar = dgd.mean, dgd.logvar
+                latents = dgd.sample()
+                latents.requires_grad_(True)
+
+                pred = self.vae.decode(latents).sample
+
+                loss = self.get_elbo_loss(pred, batch, mu, logvar)
+
+                # Backward pass and optimization
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                scheduler.step()
+
+                # update progress bar
+                loss_value = loss.item()
+                # get exponent like 3.54e-4
+                loss_string = f"{loss_value:.2e}"
+                learning_rate = optimizer.param_groups[0]['lr']
+                progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} Loss: {loss_string}")
+                progress_bar.set_description(f"E: {epoch} - S: {step} ")
+                progress_bar.update(1)
+
+                if step != 0:
+                    if self.sample_every and step % self.sample_every == 0:
+                        # print above the progress bar
+                        print(f"Sampling at step {step}")
+                        self.sample(step)
+
+                    if self.save_every and step % self.save_every == 0:
+                        # print above the progress bar
+                        print(f"Saving at step {step}")
+                        self.save(step)
+
+                step += 1
+
+        self.save()
diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py
index 1e307d53..6160dff7 100644
--- a/jobs/process/__init__.py
+++ b/jobs/process/__init__.py
@@ -3,4 +3,5 @@ from .ExtractLoconProcess import ExtractLoconProcess
 from .ExtractLoraProcess import ExtractLoraProcess
 from .BaseProcess import BaseProcess
 from .BaseTrainProcess import BaseTrainProcess
 from .TrainFineTuneProcess import TrainFineTuneProcess
+from .TrainVAEProcess import TrainVAEProcess
diff --git a/repositories/leco b/repositories/leco
new file mode 160000
index 00000000..9294adf4
--- /dev/null
+++ b/repositories/leco
@@ -0,0 +1 @@
+Subproject commit 9294adf40218e917df4516737afb13f069a6789d
diff --git a/requirements.txt b/requirements.txt
index b2c6fe65..f5355be2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 torch
+torchvision
 safetensors
 diffusers
 transformers
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
new file mode 100644
index 00000000..153888fe
--- /dev/null
+++ b/toolkit/data_loader.py
@@ -0,0 +1,67 @@
+import os
+import random
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torchvision import transforms
+from torch.utils.data import Dataset
+
+
+class ImageDataset(Dataset):
+    def __init__(self, config):
+        self.config = config
+        self.name = self.get_config('name', 'dataset')
+        self.path = self.get_config('path', required=True)
+        self.scale = self.get_config('scale', 1)
+        self.random_scale = self.get_config('random_scale', False)
+        # we always random crop if random scale is enabled
+        self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)
+
+        self.resolution = self.get_config('resolution', 256)
+        self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
+                          file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
+
+        # this might take a while
+        print(f" - Preprocessing image dimensions")
+        self.file_list = [file for file in self.file_list if
+                          int(min(Image.open(file).size) * self.scale) >= self.resolution]
+
+        print(f" - Found {len(self.file_list)} images")
+        assert len(self.file_list) > 0, f"no images found in {self.path}"
+
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ])
+
+    def get_config(self, key, default=None, required=False):
+        if key in self.config:
+            value = self.config[key]
+            return value
+        elif required:
+            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
+        else:
+            return default
+
+    def __len__(self):
+        return len(self.file_list)
+
+    def __getitem__(self, index):
+        img_path = self.file_list[index]
+        img = exif_transpose(Image.open(img_path)).convert('RGB')
+
+        # Downscale the source image first
+        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
+
+        if self.random_crop:
+            if self.random_scale:
+                scale_size = random.randint(self.resolution, min(img.size))  # img is already scaled down above
+                img = img.resize((scale_size, scale_size), Image.BICUBIC)
+            img = transforms.RandomCrop(self.resolution)(img)
+        else:
+            min_dimension = min(img.size)
+            img = transforms.CenterCrop(min_dimension)(img)
+            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
+
+        img = self.transform(img)
+
+        return img
diff --git a/toolkit/job.py b/toolkit/job.py
index 2991cc5a..62d1bcfd 100644
--- a/toolkit/job.py
+++ b/toolkit/job.py
@@ -10,6 +10,10 @@ def get_job(config_path):
     if job == 'extract':
         from jobs import ExtractJob
         return ExtractJob(config)
+    if job == 'train':
+        from jobs import TrainJob
+        return TrainJob(config)
+
     # elif job == 'train':
     #     from jobs import TrainJob
     #     return TrainJob(config)
diff --git a/toolkit/paths.py b/toolkit/paths.py
index d6ee1fe4..2eb92e06 100644
--- a/toolkit/paths.py
+++ b/toolkit/paths.py
@@ -3,3 +3,4 @@ import os
 TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
 SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
+REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
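
Example job config (a sketch, not a definitive schema). The per-process key names come straight from the get_conf()/get_config() calls in TrainJob and TrainVAEProcess; the top-level 'job'/'name'/'process'/'type' nesting is an assumption about how BaseJob.load_processes resolves process_dict, since BaseJob is not part of this diff, and all paths are placeholders.

    # hypothetical config for the new 'train' job with a 'vae' process
    config = {
        'job': 'train',                             # routed by toolkit/job.py get_job()
        'name': 'vae_decoder_finetune',             # assumed; surfaces as self.job.name in save filenames
        'training_folder': 'output/vae',            # required by TrainJob
        'device': 'cuda',
        'process': [{
            'type': 'vae',                          # maps to TrainVAEProcess via process_dict
            'vae_path': 'models/vae.safetensors',   # required
            'datasets': [{'path': 'data/images', 'scale': 1, 'random_crop': True}],
            'batch_size': 4,
            'resolution': 256,
            'learning_rate': 1e-4,
            'max_steps': 10000,                     # epochs and/or max_steps; at least one is required
            'save_every': 1000,
            'sample_every': 500,                    # requires sample_sources
            'sample_sources': ['data/val/face.png'],
            'dtype': 'float32',
        }],
    }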
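The new toolkit/data_loader.ImageDataset can be smoke-tested on its own. A minimal sketch, assuming a local folder of images whose smaller dimension (after scale) is at least the target resolution, since smaller files are filtered out in __init__:

    from torch.utils.data import DataLoader
    from toolkit.data_loader import ImageDataset

    # 'data/images' is a placeholder path
    dataset = ImageDataset({'path': 'data/images', 'resolution': 256, 'random_crop': True})
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch.shape)  # torch.Size([4, 3, 256, 256]); values normalized to [-1, 1]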
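For reference, the KLD term in get_elbo_loss is the closed-form KL divergence between the encoder posterior N(mu, sigma^2) and the unit Gaussian prior, with log_var = log(sigma^2):

    D_KL( N(mu, sigma^2) || N(0, I) ) = -1/2 * sum_j ( 1 + log(sigma_j^2) - mu_j^2 - sigma_j^2 )

which is exactly -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()). The reconstruction term pairs it with a sum-reduced MSE, i.e. a Gaussian log-likelihood up to a constant scale, which is the matching likelihood for real-valued images (as opposed to binary cross-entropy for binarized inputs).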