From b77b9acc0bf5ae890fb52c5b3c9837f7a834d1a4 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 19 Aug 2023 15:35:24 -0600 Subject: [PATCH] Added base for ultimate slider. WIP --- .../ImageReferenceSliderTrainerProcess.py | 23 +- .../UltimateSliderTrainerProcess.py | 385 ++++++++++++++++++ .../ultimate_slider_trainer/__init__.py | 25 ++ .../config/train.example.yaml | 107 +++++ toolkit/config_modules.py | 19 +- toolkit/data_loader.py | 45 +- 6 files changed, 568 insertions(+), 36 deletions(-) create mode 100644 extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py create mode 100644 extensions_built_in/ultimate_slider_trainer/__init__.py create mode 100644 extensions_built_in/ultimate_slider_trainer/config/train.example.yaml diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py index 688e172b..3036d8ef 100644 --- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py +++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py @@ -5,6 +5,8 @@ import os from contextlib import nullcontext from typing import Optional, Union, List from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig from toolkit.data_loader import PairedImageDataset from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds @@ -21,29 +23,11 @@ def flush(): gc.collect() -class DatasetConfig: - def __init__(self, **kwargs): - # can pass with a side by side pait or a folder with pos and neg folder - self.pair_folder: str = kwargs.get('pair_folder', None) - self.pos_folder: str = kwargs.get('pos_folder', None) - self.neg_folder: str = kwargs.get('neg_folder', None) - - self.network_weight: float = 
float(kwargs.get('network_weight', 1.0)) - self.pos_weight: float = float(kwargs.get('pos_weight', self.network_weight)) - self.neg_weight: float = float(kwargs.get('neg_weight', self.network_weight)) - # make sure they are all absolute values no negatives - self.pos_weight = abs(self.pos_weight) - self.neg_weight = abs(self.neg_weight) - - self.target_class: str = kwargs.get('target_class', '') - self.size: int = kwargs.get('size', 512) - - class ReferenceSliderConfig: def __init__(self, **kwargs): self.additional_losses: List[str] = kwargs.get('additional_losses', []) self.weight_jitter: float = kwargs.get('weight_jitter', 0.0) - self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])] + self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])] class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): @@ -236,7 +220,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): loss.backward() flush() - # apply gradients optimizer.step() lr_scheduler.step() diff --git a/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py b/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py new file mode 100644 index 00000000..a6bb2f46 --- /dev/null +++ b/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py @@ -0,0 +1,385 @@ +import copy +import random +from collections import OrderedDict +import os +from contextlib import nullcontext +from typing import Optional, Union, List +from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig +from toolkit.data_loader import PairedImageDataset +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +from toolkit import train_tools +import torch 
+from jobs.process import BaseSDTrainProcess +import random + +import random +from collections import OrderedDict +from tqdm import tqdm + +from toolkit.config_modules import SliderConfig +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +from toolkit import train_tools +from toolkit.prompt_utils import \ + EncodedPromptPair, ACTION_TYPES_SLIDER, \ + EncodedAnchor, concat_prompt_pairs, \ + concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \ + split_prompt_pairs + +import torch + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class UltimateSliderConfig(SliderConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.additional_losses: List[str] = kwargs.get('additional_losses', []) + self.weight_jitter: float = kwargs.get('weight_jitter', 0.0) + self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])] + + +class UltimateSliderTrainerProcess(BaseSDTrainProcess): + sd: StableDiffusion + data_loader: DataLoader = None + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.prompt_txt_list = None + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.slider_config = UltimateSliderConfig(**self.get_conf('slider', {})) + + self.prompt_cache = PromptEmbedsCache() + self.prompt_pairs: list[EncodedPromptPair] = [] + self.anchor_pairs: list[EncodedAnchor] = [] + # keep track of prompt chunk size + self.prompt_chunk_size = 1 + + # store a list of all the prompts from the dataset so we can cache it + self.dataset_prompts = [] + self.train_with_dataset = self.slider_config.datasets is not None and len(self.slider_config.datasets) > 0 + + def load_datasets(self): + if self.data_loader is None and \ + self.slider_config.datasets 
is not None and len(self.slider_config.datasets) > 0: + print(f"Loading datasets") + datasets = [] + for dataset in self.slider_config.datasets: + print(f" - Dataset: {dataset.pair_folder}") + config = { + 'path': dataset.pair_folder, + 'size': dataset.size, + 'default_prompt': dataset.target_class, + 'network_weight': dataset.network_weight, + 'pos_weight': dataset.pos_weight, + 'neg_weight': dataset.neg_weight, + 'pos_folder': dataset.pos_folder, + 'neg_folder': dataset.neg_folder, + } + image_dataset = PairedImageDataset(config) + datasets.append(image_dataset) + + # capture all the prompts from it so we can cache the embeds + self.dataset_prompts += image_dataset.get_all_prompts() + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.train_config.batch_size, + shuffle=True, + num_workers=2 + ) + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + # load any datasets if they were passed + self.load_datasets() + + # read line by line from file + if self.slider_config.prompt_file: + self.print(f"Loading prompt file from {self.slider_config.prompt_file}") + with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f: + self.prompt_txt_list = f.readlines() + # clean empty lines + self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] + + self.print(f"Found {len(self.prompt_txt_list)} prompts.") + + if not self.slider_config.prompt_tensors: + print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.") + # shuffle + random.shuffle(self.prompt_txt_list) + # trim to max steps + self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps] + # trim list to our max steps + + cache = PromptEmbedsCache() + + # get encoded latents for our prompts + with torch.no_grad(): + # list of neutrals. 
Can come from file or be empty + neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""] + + # build the prompts to cache + prompts_to_cache = [] + for neutral in neutral_list: + for target in self.slider_config.targets: + prompt_list = [ + f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral + f"{target.positive}", # positive_target + f"{target.positive} {neutral}", # positive_target with neutral + f"{target.negative}", # negative_target + f"{target.negative} {neutral}", # negative_target with neutral + f"{neutral}", # neutral + f"{target.positive} {target.negative}", # both targets + f"{target.negative} {target.positive}", # both targets reverse + ] + prompts_to_cache += prompt_list + + # remove duplicates + prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) + + # trim to max steps if max steps is lower than prompt count + prompts_to_cache = prompts_to_cache[:self.train_config.steps] + + if len(self.dataset_prompts) > 0: + # add the prompts from the dataset + prompts_to_cache += self.dataset_prompts + + # encode them + cache = encode_prompts_to_cache( + prompt_list=prompts_to_cache, + sd=self.sd, + cache=cache, + prompt_tensor_file=self.slider_config.prompt_tensors + ) + + prompt_pairs = [] + prompt_batches = [] + for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False): + for target in self.slider_config.targets: + prompt_pair_batch = build_prompt_pair_batch_from_cache( + cache=cache, + target=target, + neutral=neutral, + + ) + if self.slider_config.batch_full_slide: + # concat the prompt pairs + # this allows us to run the entire 4 part process in one shot (for slider) + self.prompt_chunk_size = 4 + concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu') + prompt_pairs += [concat_prompt_pair_batch] + else: + self.prompt_chunk_size = 1 + # do them one at a time (probably not necessary after new optimizations) + prompt_pairs += 
[x.to('cpu') for x in prompt_pair_batch] + + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + self.prompt_pairs = prompt_pairs + # end hook_before_train_loop + + # move vae to device so we can encode on the fly + # todo cache latents + self.sd.vae.to(self.device_torch) + self.sd.vae.eval() + self.sd.vae.requires_grad_(False) + + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + + flush() + # end hook_before_train_loop + + def hook_train_loop(self, batch): + with torch.no_grad(): + ### LOOP SETUP ### + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + ### PREP REFERENCE IMAGES ### + + imgs, prompts, network_weights = batch + network_pos_weight, network_neg_weight = network_weights + + if isinstance(network_pos_weight, torch.Tensor): + network_pos_weight = network_pos_weight.item() + if isinstance(network_neg_weight, torch.Tensor): + network_neg_weight = network_neg_weight.item() + + # draw one random float between -weight_jitter and weight_jitter and add it to both weights + weight_jitter = self.slider_config.weight_jitter + if weight_jitter > 0.0: + jitter_list = random.uniform(-weight_jitter, weight_jitter) + network_pos_weight += jitter_list + network_neg_weight += jitter_list + + # if items in network_weight list are tensors, convert them to floats + + dtype = get_torch_dtype(self.train_config.dtype) + imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype) + # split batched images in half so left is negative and right is positive + negative_images, positive_images = torch.chunk(imgs, 2, dim=3) + + height = positive_images.shape[2] + width = positive_images.shape[3] + batch_size = 
positive_images.shape[0] + + positive_latents = self.sd.encode_images(positive_images) + negative_latents = self.sd.encode_images(negative_images) + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + timesteps = timesteps.long() + + # get noise + noise_positive = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + noise_negative = noise_positive.clone() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps) + noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps) + + noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0) + noise = torch.cat([noise_positive, noise_negative], dim=0) + timesteps = torch.cat([timesteps, timesteps], dim=0) + network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0] + + flush() + + loss_float = None + loss_mirror_float = None + + self.optimizer.zero_grad() + noisy_latents.requires_grad = False + + # TODO allow both processes to train the text encoder; for now, we just do the unet and cache all text encodes + # if training text encoder enable grads, else do context of no grad + # with torch.set_grad_enabled(self.train_config.train_text_encoder): + # # text encoding + # embedding_list = [] + # # embed the prompts + # for prompt in prompts: + # embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + # embedding_list.append(embedding) + # conditional_embeds = concat_prompt_embeds(embedding_list) + # conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + 
+ if self.train_with_dataset: + embedding_list = [] + with torch.set_grad_enabled(self.train_config.train_text_encoder): + for prompt in prompts: + # get embedding from cache + embedding = self.prompt_cache[prompt] + embedding = embedding.to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + conditional_embeds = concat_prompt_embeds(embedding_list) + # double up so we can do both sides of the slider + conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + else: + # throw error. Not supported yet + raise Exception("Datasets and targets required for ultimate slider") + + if self.model_config.is_xl: + # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop + network_multiplier_list = network_multiplier + noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0) + noise_list = torch.chunk(noise, 2, dim=0) + timesteps_list = torch.chunk(timesteps, 2, dim=0) + conditional_embeds_list = split_prompt_embeds(conditional_embeds) + else: + network_multiplier_list = [network_multiplier] + noisy_latent_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditional_embeds_list = [conditional_embeds] + + losses = [] + # allow to chunk it out to save vram + for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip( + network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list + ): + with self.network: + assert self.network.is_active + + self.network.multiplier = network_multiplier + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + ) + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise 
+ + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + # todo add snr gamma here + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + loss_slide_float = loss.item() + + loss_float = loss.item() + losses.append(loss_float) + + # back propagate loss to free ram + loss.backward() + flush() + + # apply gradients + optimizer.step() + lr_scheduler.step() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': sum(losses) / len(losses) if len(losses) > 0 else 0.0} + ) + + return loss_dict + # end hook_train_loop diff --git a/extensions_built_in/ultimate_slider_trainer/__init__.py b/extensions_built_in/ultimate_slider_trainer/__init__.py new file mode 100644 index 00000000..8c7006db --- /dev/null +++ b/extensions_built_in/ultimate_slider_trainer/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. 
+from toolkit.extension import Extension + + +# We make a subclass of Extension +class UltimateSliderTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "ultimate_slider_trainer" + + # name is the name of the extension for printing + name = "Ultimate Slider Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .UltimateSliderTrainerProcess import UltimateSliderTrainerProcess + return UltimateSliderTrainerProcess + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + UltimateSliderTrainer +] diff --git a/extensions_built_in/ultimate_slider_trainer/config/train.example.yaml b/extensions_built_in/ultimate_slider_trainer/config/train.example.yaml new file mode 100644 index 00000000..8b0f4734 --- /dev/null +++ b/extensions_built_in/ultimate_slider_trainer/config/train.example.yaml @@ -0,0 +1,107 @@ +--- +job: extension +config: + name: example_name + process: + - type: 'ultimate_slider_trainer' + training_folder: "/mnt/Train/out/LoRA" + device: cuda:0 + # for tensorboard logging + log_dir: "/home/jaret/Dev/.tensorboard" + network: + type: "lora" + linear: 8 + linear_alpha: 8 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 5000 + lr: 1e-4 + train_unet: true + gradient_checkpointing: true + train_text_encoder: true + optimizer: "adamw" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 1 + dtype: bf16 + xformers: true + skip_first_sample: true + noise_offset: 0.0 + model: + name_or_path: "/path/to/model.safetensors" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + save: + dtype: float16 # precision to save + save_every: 1000 # save every 
this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man smiling at the camera, in a tank top --m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m 1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + + slider: + datasets: + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 2.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 4.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + + +# you can put any information you want here, and it will be saved in 
the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. \ No newline at end of file diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index e41a1030..c070c904 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -4,7 +4,6 @@ from typing import List, Optional import random - class SaveConfig: def __init__(self, **kwargs): self.save_every: int = kwargs.get('save_every', 1000) @@ -87,6 +86,24 @@ class ModelConfig: raise ValueError('name_or_path must be specified') +class ReferenceDatasetConfig: + def __init__(self, **kwargs): + # can pass with a side by side pair or a folder with pos and neg folder + self.pair_folder: str = kwargs.get('pair_folder', None) + self.pos_folder: str = kwargs.get('pos_folder', None) + self.neg_folder: str = kwargs.get('neg_folder', None) + + self.network_weight: float = float(kwargs.get('network_weight', 1.0)) + self.pos_weight: float = float(kwargs.get('pos_weight', self.network_weight)) + self.neg_weight: float = float(kwargs.get('neg_weight', self.network_weight)) + # make sure they are all absolute values no negatives + self.pos_weight = abs(self.pos_weight) + self.neg_weight = abs(self.neg_weight) + + self.target_class: str = kwargs.get('target_class', '') + self.size: int = kwargs.get('size', 512) + + class SliderTargetConfig: def __init__(self, **kwargs): self.target_class: str = kwargs.get('target_class', '') diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index a6b2a7a0..e4ddd51d 100644 --- a/toolkit/data_loader.py +++ 
b/toolkit/data_loader.py @@ -163,7 +163,7 @@ class PairedImageDataset(Dataset): self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if file.lower().endswith(supported_exts)] self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if - file.lower().endswith(supported_exts)] + file.lower().endswith(supported_exts)] matched_files = [] for pos_file in self.pos_file_list: @@ -177,7 +177,6 @@ class PairedImageDataset(Dataset): # remove duplicates matched_files = [t for t in (set(tuple(i) for i in matched_files))] - self.file_list = matched_files print(f" - Found {len(self.file_list)} matching pairs") else: @@ -190,6 +189,15 @@ class PairedImageDataset(Dataset): transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] ]) + def get_all_prompts(self): + prompts = [] + for index in range(len(self.file_list)): + prompts.append(self.get_prompt_item(index)) + + # remove duplicates + prompts = list(set(prompts)) + return prompts + def __len__(self): return len(self.file_list) @@ -202,19 +210,9 @@ class PairedImageDataset(Dataset): else: return default - def __getitem__(self, index): + def get_prompt_item(self, index): img_path_or_tuple = self.file_list[index] if isinstance(img_path_or_tuple, tuple): - # load both images - img_path = img_path_or_tuple[0] - img1 = exif_transpose(Image.open(img_path)).convert('RGB') - img_path = img_path_or_tuple[1] - img2 = exif_transpose(Image.open(img_path)).convert('RGB') - # combine them side by side - img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height))) - img.paste(img1, (0, 0)) - img.paste(img2, (img1.width, 0)) - # check if either has a prompt file path_no_ext = os.path.splitext(img_path_or_tuple[0])[0] prompt_path = path_no_ext + '.txt' @@ -223,7 +221,6 @@ class PairedImageDataset(Dataset): prompt_path = path_no_ext + '.txt' else: img_path = img_path_or_tuple - img = exif_transpose(Image.open(img_path)).convert('RGB') # see 
if prompt file exists path_no_ext = os.path.splitext(img_path)[0] prompt_path = path_no_ext + '.txt' @@ -242,6 +239,25 @@ class PairedImageDataset(Dataset): prompt = ', '.join(prompt_split) else: prompt = self.default_prompt + return prompt + + def __getitem__(self, index): + img_path_or_tuple = self.file_list[index] + if isinstance(img_path_or_tuple, tuple): + # load both images + img_path = img_path_or_tuple[0] + img1 = exif_transpose(Image.open(img_path)).convert('RGB') + img_path = img_path_or_tuple[1] + img2 = exif_transpose(Image.open(img_path)).convert('RGB') + # combine them side by side + img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height))) + img.paste(img1, (0, 0)) + img.paste(img2, (img1.width, 0)) + else: + img_path = img_path_or_tuple + img = exif_transpose(Image.open(img_path)).convert('RGB') + + prompt = self.get_prompt_item(index) height = self.size # determine width to keep aspect ratio @@ -252,4 +268,3 @@ class PairedImageDataset(Dataset): img = self.transform(img) return img, prompt, (self.neg_weight, self.pos_weight) -