diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
new file mode 100644
index 00000000..e9f7d0ef
--- /dev/null
+++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
@@ -0,0 +1,202 @@
+import copy
+import random
+from collections import OrderedDict
+import os
+from typing import Optional, Union, List
+from torch.utils.data import ConcatDataset, DataLoader
+from toolkit.data_loader import PairedImageDataset
+from toolkit.prompt_utils import concat_prompt_embeds
+from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.train_tools import get_torch_dtype
+import gc
+from toolkit import train_tools
+import torch
+from jobs.process import BaseSDTrainProcess
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class ReferenceSliderConfig:
+    def __init__(self, **kwargs):
+        self.slider_pair_folder: Optional[str] = kwargs.get('slider_pair_folder', None)
+        self.resolutions: List[int] = kwargs.get('resolutions', [512])
+        self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
+        self.target_class: str = kwargs.get('target_class', '')
+
+
+class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
+    sd: StableDiffusion
+    data_loader: DataLoader = None
+
+    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
+        super().__init__(process_id, job, config, **kwargs)
+        self.prompt_txt_list = None
+        self.step_num = 0
+        self.start_step = 0
+        self.device = self.get_conf('device', self.job.device)
+        self.device_torch = torch.device(self.device)
+        self.slider_config = ReferenceSliderConfig(**self.get_conf('slider', {}))
+
+    def load_datasets(self):
+        if self.data_loader is None:
+            print("Loading datasets")
+            datasets = []
+            for res in self.slider_config.resolutions:
+                print(f" - Dataset: {self.slider_config.slider_pair_folder}")
+                config = {
+                    'path': self.slider_config.slider_pair_folder,
+                    'size': res,
+                    'default_prompt': self.slider_config.target_class
+                }
+                image_dataset = PairedImageDataset(config)
+                datasets.append(image_dataset)
+
+            concatenated_dataset = ConcatDataset(datasets)
+            self.data_loader = DataLoader(
+                concatenated_dataset,
+                batch_size=self.train_config.batch_size,
+                shuffle=True,
+                num_workers=2
+            )
+
+    def before_model_load(self):
+        pass
+
+    def hook_before_train_loop(self):
+        self.sd.vae.eval()
+        self.sd.vae.to(self.device_torch)
+        self.load_datasets()
+
+        pass
+
+    def hook_train_loop(self, batch):
+        with torch.no_grad():
+            imgs, prompts = batch
+            dtype = get_torch_dtype(self.train_config.dtype)
+            imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
+
+            # split the batched images in half so left is negative and right is positive
+            negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
+
+            height = positive_images.shape[2]
+            width = positive_images.shape[3]
+            batch_size = positive_images.shape[0]
+
+            # encode the images into latents and apply the SD VAE scaling factor
+            positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
+            positive_latents = positive_latents * 0.18215
+            negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
+            negative_latents = negative_latents * 0.18215
+
+            embedding_list = []
+            negative_embedding_list = []
+            # embed the prompts
+            for prompt in prompts:
+                embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+                embedding_list.append(embedding)
+                # just empty for now
+                # todo cache this?
+                negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
+                negative_embedding_list.append(negative_embed)
+
+            conditional_embeds = concat_prompt_embeds(embedding_list)
+            unconditional_embeds = concat_prompt_embeds(negative_embedding_list)
+
+        if self.train_config.gradient_checkpointing:
+            # may get disabled elsewhere
+            self.sd.unet.enable_gradient_checkpointing()
+
+        noise_scheduler = self.sd.noise_scheduler
+        optimizer = self.optimizer
+        lr_scheduler = self.lr_scheduler
+        loss_function = torch.nn.MSELoss()  # note: unused; torch.nn.functional.mse_loss is called directly below
+
+        def get_noise_pred(neg, pos, gs, cts, dn):
+            return self.sd.predict_noise(
+                latents=dn,  # dn: the noisy latents to denoise
+                text_embeddings=train_tools.concat_prompt_embeddings(
+                    neg,  # negative (unconditional) prompt embeds
+                    pos,  # positive (conditional) prompt embeds
+                    self.train_config.batch_size,
+                ),
+                timestep=cts,  # cts: current timesteps
+                guidance_scale=gs,  # gs: guidance scale
+            )
+
+        with torch.no_grad():
+            self.sd.noise_scheduler.set_timesteps(
+                self.train_config.max_denoising_steps, device=self.device_torch
+            )
+
+            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
+            timesteps = timesteps.long()
+
+            # get noise
+            noise = self.sd.get_latent_noise(
+                pixel_height=height,
+                pixel_width=width,
+                batch_size=batch_size,
+                noise_offset=self.train_config.noise_offset,
+            ).to(self.device_torch, dtype=dtype)
+
+            # Add noise to the latents according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps)
+            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps)
+
+        flush()
+
+        self.optimizer.zero_grad()
+        with self.network:
+            assert self.network.is_active
+            loss_list = []
+            for noisy_latents, network_multiplier in zip(
+                    [noisy_positive_latents, noisy_negative_latents],
+                    [1.0, -1.0],
+            ):
+                # positive pass first, then negative; the multiplier flips the LoRA direction
+                self.network.multiplier = network_multiplier
+
+                noise_pred = get_noise_pred(
+                    unconditional_embeds,
+                    conditional_embeds,
+                    1,
+                    timesteps,
+                    noisy_latents
+                )
+
+                if self.sd.is_v2:  # todo: check the scheduler's prediction_type instead of assuming all v2 models are v-prediction
+                    # v-parameterization training
+                    target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
+                else:
+                    target = noise
+
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = loss.mean([1, 2, 3])
+
+                # todo add snr gamma here
+
+                loss = loss.mean()
+                # back propagate now to free up memory
+                loss.backward()
+                loss_list.append(loss.item())
+
+                flush()
+
+        # apply gradients
+        optimizer.step()
+        lr_scheduler.step()
+
+        loss_float = sum(loss_list) / len(loss_list)
+
+        # reset network
+        self.network.multiplier = 1.0
+
+        loss_dict = OrderedDict(
+            {'loss': loss_float},
+        )
+        return loss_dict
+        # end hook_train_loop
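Note on the data format the trainer above expects: each dataset image is a single file holding the negative example on the left half and the positive example on the right half, which is why `torch.chunk(imgs, 2, dim=3)` splits the batch along the width axis. A minimal sketch of that convention (illustrative only, not part of the diff):

```python
import torch

# a batch of 2 side-by-side pairs: height 512, width 2 * 512 (negative|positive)
imgs = torch.randn(2, 3, 512, 1024)

# dim=3 is the width axis, so this yields the left and right halves
negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
assert negative_images.shape == (2, 3, 512, 512)
assert positive_images.shape == (2, 3, 512, 512)
```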
diff --git a/extensions_built_in/image_reference_slider_trainer/__init__.py b/extensions_built_in/image_reference_slider_trainer/__init__.py
new file mode 100644
index 00000000..8a15f646
--- /dev/null
+++ b/extensions_built_in/image_reference_slider_trainer/__init__.py
@@ -0,0 +1,25 @@
+# This is an example extension for custom training. It is great for experimenting with new ideas.
+from toolkit.extension import Extension
+
+
+# We make a subclass of Extension
+class ImageReferenceSliderTrainer(Extension):
+    # uid must be unique; it is how the extension is identified
+    uid = "image_reference_slider_trainer"
+
+    # name is the name of the extension for printing
+    name = "Image Reference Slider Trainer"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed, and return it
+        from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess
+        return ImageReferenceSliderTrainerProcess
+
+
+AI_TOOLKIT_EXTENSIONS = [
+    # you can put a list of extensions here
+    ImageReferenceSliderTrainer
+]
diff --git a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
new file mode 100644
index 00000000..e69de29b
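`train.example.yaml` is added empty in this diff (`e69de29b` is git's empty-blob hash). As a hedged sketch of what the `slider` section consumed via `get_conf('slider', {})` might look like, here is a Python rendering that uses only the keys `ReferenceSliderConfig` actually reads; the values and the import path are illustrative assumptions:

```python
# Hypothetical example; the key names mirror ReferenceSliderConfig's kwargs.
from extensions_built_in.image_reference_slider_trainer.ImageReferenceSliderTrainerProcess import (
    ReferenceSliderConfig,
)

slider_section = {
    'slider_pair_folder': '/path/to/paired/images',  # side-by-side negative|positive pairs
    'resolutions': [512],       # one PairedImageDataset is created per resolution
    'batch_full_slide': True,   # stored by the config but not yet used in this diff
    'target_class': 'person',   # forwarded as the dataset's default_prompt
}
slider_config = ReferenceSliderConfig(**slider_section)
```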
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e50ab66e..0606b9de 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1,6 +1,9 @@
 import glob
 from collections import OrderedDict
 import os
+from typing import Union
+
+from torch.utils.data import DataLoader
 
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -54,6 +57,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
+        self.data_loader: Union[DataLoader, None] = None
 
         self.sd = StableDiffusion(
             device=self.device,
@@ -193,7 +197,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def hook_before_train_loop(self):
         pass
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch=None):
         # return loss
         return 0.0
 
@@ -358,12 +362,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
             iterable=range(0, self.train_config.steps),
         )
 
+        if self.data_loader is not None:
+            dataloader = self.data_loader
+            dataloader_iterator = iter(dataloader)
+        else:
+            dataloader = None
+            dataloader_iterator = None
+
         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
-            # todo handle dataloader here maybe, not sure
+            if dataloader is not None:
+                try:
+                    batch = next(dataloader_iterator)
+                except StopIteration:
+                    # hit the end of an epoch, reset the iterator
+                    # todo, should we do something else here? like blow up balloons?
+                    dataloader_iterator = iter(dataloader)
+                    batch = next(dataloader_iterator)
+            else:
+                batch = None
 
             ### HOOK ###
-            loss_dict = self.hook_train_loop()
+            loss_dict = self.hook_train_loop(batch)
+            flush()
 
             if self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):
diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py
index cd6f8619..d1c65bf7 100644
--- a/jobs/process/BaseTrainProcess.py
+++ b/jobs/process/BaseTrainProcess.py
@@ -29,11 +29,11 @@ class BaseTrainProcess(BaseProcess):
         super().__init__(process_id, job, config)
         self.progress_bar = None
         self.writer = None
-        self.training_folder = self.get_conf('training_folder', self.job.training_folder)
-        self.save_root = os.path.join(self.training_folder, self.job.name)
+        self.training_folder = self.get_conf('training_folder', self.job.training_folder if hasattr(self.job, 'training_folder') else None)
+        self.save_root = os.path.join(self.training_folder, self.name)
         self.step = 0
         self.first_step = 0
-        self.log_dir = self.get_conf('log_dir', self.job.log_dir)
+        self.log_dir = self.get_conf('log_dir', self.job.log_dir if hasattr(self.job, 'log_dir') else None)
 
         self.setup_tensorboard()
         self.save_training_config()
@@ -62,7 +62,7 @@ class BaseTrainProcess(BaseProcess):
 
     def save_training_config(self):
         timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
-        os.makedirs(self.training_folder, exist_ok=True)
-        save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml')
+        os.makedirs(self.save_root, exist_ok=True)
+        save_dif = os.path.join(self.save_root, f'process_config_{timestamp}.yaml')
         with open(save_dif, 'w') as f:
             yaml.dump(self.raw_process_config, f)
diff --git a/jobs/process/TrainLoRAHack.py b/jobs/process/TrainLoRAHack.py
index a3fb118d..2a5a6539 100644
--- a/jobs/process/TrainLoRAHack.py
+++ b/jobs/process/TrainLoRAHack.py
@@ -68,7 +68,7 @@ class TrainLoRAHack(BaseSDTrainProcess):
 
         return loss_dict
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         if self.hack_config.type == 'suppression':
             return self.supress_loop()
         else:
diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py
index d7464d5c..cc2dc339 100644
--- a/jobs/process/TrainSDRescaleProcess.py
+++ b/jobs/process/TrainSDRescaleProcess.py
@@ -210,7 +210,7 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
         flush()
         # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
         loss_function = torch.nn.MSELoss()
 
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index a79e67dc..46cef6b7 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -173,7 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
         flush()
         # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # get a random pair
diff --git a/jobs/process/TrainSliderProcessOld.py b/jobs/process/TrainSliderProcessOld.py
index 9a673c46..8c25393a 100644
--- a/jobs/process/TrainSliderProcessOld.py
+++ b/jobs/process/TrainSliderProcessOld.py
@@ -221,7 +221,7 @@ class TrainSliderProcessOld(BaseSDTrainProcess):
         flush()
         # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # get a random pair
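The `hook_train_loop` overrides above all gain a `batch` parameter to match the new base-class signature. A minimal sketch of the resulting contract for a custom process (the subclass is hypothetical, for illustration only):

```python
from collections import OrderedDict

from jobs.process import BaseSDTrainProcess


class MyCustomTrainProcess(BaseSDTrainProcess):  # hypothetical example subclass
    def hook_train_loop(self, batch):
        # batch is None unless self.data_loader was assigned before the loop starts;
        # otherwise it is whatever the DataLoader yields, e.g. (imgs, prompts).
        if batch is not None:
            imgs, prompts = batch
        # ... training step goes here ...
        return OrderedDict({'loss': 0.0})
```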
diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py
index f731db1f..766c6b11 100644
--- a/jobs/process/__init__.py
+++ b/jobs/process/__init__.py
@@ -13,3 +13,4 @@ from .ModRescaleLoraProcess import ModRescaleLoraProcess
 from .GenerateProcess import GenerateProcess
 from .BaseExtensionProcess import BaseExtensionProcess
 from .TrainESRGANProcess import TrainESRGANProcess
+from .BaseSDTrainProcess import BaseSDTrainProcess
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 75ece7c9..1e2264ac 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -140,3 +140,65 @@ class AugmentedImageDataset(ImageDataset):
         # return both
         # return image as 0 - 1 tensor
         return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)
+
+
+class PairedImageDataset(Dataset):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.size = self.get_config('size', 512)
+        self.path = self.get_config('path', required=True)
+        self.default_prompt = self.get_config('default_prompt', '')
+        self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
+                          file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
+        print(f" - Found {len(self.file_list)} images")
+
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
+        ])
+
+    def __len__(self):
+        return len(self.file_list)
+
+    def get_config(self, key, default=None, required=False):
+        if key in self.config:
+            value = self.config[key]
+            return value
+        elif required:
+            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
+        else:
+            return default
+
+    def __getitem__(self, index):
+        img_path = self.file_list[index]
+        img = exif_transpose(Image.open(img_path)).convert('RGB')
+
+        # see if a matching prompt file exists
+        path_no_ext = os.path.splitext(img_path)[0]
+        prompt_path = path_no_ext + '.txt'
+        if os.path.exists(prompt_path):
+            with open(prompt_path, 'r', encoding='utf-8') as f:
+                prompt = f.read()
+                # replace newlines with commas
+                prompt = prompt.replace('\n', ', ')
+                # also handle Windows carriage returns
+                prompt = prompt.replace('\r', ', ')
+                prompt_split = prompt.split(',')
+                # remove empty strings
+                prompt_split = [p.strip() for p in prompt_split if p.strip()]
+                # join back together
+                prompt = ', '.join(prompt_split)
+        else:
+            prompt = self.default_prompt
+
+        height = self.size
+        # determine the width that keeps the aspect ratio
+        width = int(img.size[0] * height / img.size[1])
+
+        # resize to the target height (bicubic); may up- or downscale
+        img = img.resize((width, height), Image.BICUBIC)
+        img = self.transform(img)
+
+        return img, prompt
+
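A short usage sketch for the new `PairedImageDataset` (the path and prompt are illustrative): it returns the full side-by-side pair as a `[-1, 1]`-normalized tensor plus a prompt string, so it can be wrapped in a `DataLoader` exactly as `load_datasets` does above. Note that since each image's width is derived from its own aspect ratio, default collation assumes every pair in a batch shares the same aspect ratio.

```python
from torch.utils.data import DataLoader

from toolkit.data_loader import PairedImageDataset

dataset = PairedImageDataset({
    'path': '/path/to/paired/images',  # illustrative path
    'size': 512,                       # target height; width follows the aspect ratio
    'default_prompt': 'person',        # used when no matching .txt caption exists
})
loader = DataLoader(dataset, batch_size=4, shuffle=True)
imgs, prompts = next(iter(loader))     # imgs: (4, 3, 512, W) in [-1, 1]
```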
diff --git a/toolkit/extension.py b/toolkit/extension.py
index cb47329d..8d1f38e5 100644
--- a/toolkit/extension.py
+++ b/toolkit/extension.py
@@ -25,25 +25,26 @@ class Extension(object):
 
 
 def get_all_extensions() -> List[Extension]:
-    # Get the path of the "extensions" directory
-    extensions_dir = os.path.join(TOOLKIT_ROOT, "extensions")
+    extension_folders = ['extensions', 'extensions_built_in']
 
     # This will hold the classes from all extension modules
     all_extension_classes: List[Extension] = []
 
     # Iterate over all directories (i.e., packages) in the "extensions" directory
-    for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
-        try:
-            # Import the module
-            module = importlib.import_module(f"extensions.{name}")
-            # Get the value of the AI_TOOLKIT_EXTENSIONS variable
-            extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
-            # Check if the value is a list
-            if isinstance(extensions, list):
-                # Iterate over the list and add the classes to the main list
-                all_extension_classes.extend(extensions)
-        except ImportError as e:
-            print(f"Failed to import the {name} module. Error: {str(e)}")
+    for sub_dir in extension_folders:
+        extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
+        for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
+            try:
+                # Import the module
+                module = importlib.import_module(f"{sub_dir}.{name}")
+                # Get the value of the AI_TOOLKIT_EXTENSIONS variable
+                extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
+                # Check if the value is a list
+                if isinstance(extensions, list):
+                    # Iterate over the list and add the classes to the main list
+                    all_extension_classes.extend(extensions)
+            except ImportError as e:
+                print(f"Failed to import the {name} module. Error: {str(e)}")
 
     return all_extension_classes
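With the discovery loop now scanning both folders, a quick way to confirm the new built-in extension is picked up (assumes `TOOLKIT_ROOT` is on `sys.path` so `extensions_built_in` is importable as a top-level package):

```python
from toolkit.extension import get_all_extensions

for ext in get_all_extensions():
    print(f"{ext.uid}: {ext.name}")
# expected to include:
#   image_reference_slider_trainer: Image Reference Slider Trainer
```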