From 67dfd9ced0705dc81eb1d1780e2e35e941d079eb Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 10 Aug 2023 16:20:38 -0600 Subject: [PATCH] Added inbuild plugins and made one for image referenced. WIP --- .../ImageReferenceSliderTrainerProcess.py | 94 +++++++++++++------ .../config/train.example.yaml | 92 ++++++++++++++++++ 2 files changed, 156 insertions(+), 30 deletions(-) diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py index e9f7d0ef..a4816bec 100644 --- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py +++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py @@ -25,6 +25,7 @@ class ReferenceSliderConfig: self.resolutions: List[int] = kwargs.get('resolutions', [512]) self.batch_full_slide: bool = kwargs.get('batch_full_slide', True) self.target_class: int = kwargs.get('target_class', '') + self.additional_losses: List[str] = kwargs.get('additional_losses', []) class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): @@ -73,6 +74,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): pass def hook_train_loop(self, batch): + do_mirror_loss = 'mirror' in self.slider_config.additional_losses + with torch.no_grad(): imgs, prompts = batch dtype = get_torch_dtype(self.train_config.dtype) @@ -135,62 +138,89 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): timesteps = timesteps.long() # get noise - noise = self.sd.get_latent_noise( + noise_positive = self.sd.get_latent_noise( pixel_height=height, pixel_width=width, batch_size=batch_size, noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) + if do_mirror_loss: + # mirror the noise + noise_negative = torch.flip(noise_positive.clone(), dims=[3]) + else: + noise_negative = noise_positive.clone() + # Add noise to the latents according to 
the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps) - noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps) + noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps) + noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps) + + noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0) + noise = torch.cat([noise_positive, noise_negative], dim=0) + timesteps = torch.cat([timesteps, timesteps], dim=0) + conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) + unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds]) + network_multiplier = [1.0, -1.0] flush() + loss_float = None + loss_slide_float = None + loss_mirror_float = None + self.optimizer.zero_grad() with self.network: assert self.network.is_active loss_list = [] - for noisy_latents, network_multiplier in zip( - [noisy_positive_latents, noisy_negative_latents], - [1.0, -1.0], - ): - # do positive first - self.network.multiplier = network_multiplier - noise_pred = get_noise_pred( - unconditional_embeds, - conditional_embeds, - 1, - timesteps, - noisy_latents - ) + # do positive first + self.network.multiplier = network_multiplier - if self.sd.is_v2: # check is vpred, don't want to track it down right now - # v-parameterization training - target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) - else: - target = noise + noise_pred = get_noise_pred( + unconditional_embeds, + conditional_embeds, + 1, + timesteps, + noisy_latents + ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + if self.sd.is_v2: # check is vpred, don't want to track it down right now + # v-parameterization training + target = 
noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise - # todo add snr gamma here + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - loss = loss.mean() - # back propagate loss to free ram - loss.backward() - loss_list.append(loss.item()) + # todo add snr gamma here - flush() + loss = loss.mean() + loss_slide_float = loss.item() + + + if do_mirror_loss: + noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0) + # mirror the negative + noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3]) + loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none") + loss_mirror = loss_mirror.mean([1, 2, 3]) + loss_mirror = loss_mirror.mean() + loss_mirror_float = loss_mirror.item() + loss += loss_mirror + + loss_float = loss.item() + + # back propagate loss to free ram + loss.backward() + + flush() # apply gradients optimizer.step() lr_scheduler.step() - loss_float = sum(loss_list) / len(loss_list) # reset network self.network.multiplier = 1.0 @@ -198,5 +228,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): loss_dict = OrderedDict( {'loss': loss_float}, ) + + if do_mirror_loss: + loss_dict['l/s'] = loss_slide_float + loss_dict['l/m'] = loss_mirror_float return loss_dict # end hook_train_loop diff --git a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml index e69de29b..301790f3 100644 --- a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml +++ b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml @@ -0,0 +1,92 @@ +--- +job: extension +config: + name: subject_turner_v1 + process: + - type: 'image_reference_slider_trainer' + training_folder: "/mnt/Train/out/LoRA" + device: cuda:0 + # for tensorboard logging + log_dir: 
"/home/jaret/Dev/.tensorboard" + network: + type: "lierla" # lierla is traditional LoRA that works everywhere, only linear layers + rank: 16 + alpha: 8 + train: + noise_scheduler: "ddpm" # or "lms", "euler_a" + steps: 1000 + lr: 5e-5 + train_unet: true + gradient_checkpointing: true + train_text_encoder: false + optimizer: "lion8bit" + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 1 + dtype: bf16 + xformers: true + skip_first_sample: true + noise_offset: 0.0 # not recommended unless you are trying to make a darker LoRA; in that case use 0.1 at most + model: +# name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sdxl/sd_xl_base_0.9.safetensors" + name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors" + # name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sd_v1-5_vae.ckpt" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 20 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man smiling at the camera, in a tank top --m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m
1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + + slider: + resolutions: + - 512 + slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner" + target_class: "photo of a person" + + +meta: + name: "[name]" + version: '1.0' + creator: + name: Ostris - Jaret Burkett + email: jaret@ostris.com + website: https://ostris.com