From 1e50b39442b4cb9ad5c571319807daff6491f33a Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 28 Jul 2023 18:11:10 -0600
Subject: [PATCH] Work on slider rework

---
 jobs/process/BaseSDTrainProcess.py |   5 +-
 jobs/process/TrainSliderProcess.py | 305 ++++++++++++++++-------
 toolkit/optimizer.py               |  11 ++
 3 files changed, 187 insertions(+), 134 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 71dbd7ac..d82ff306 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -617,7 +617,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.train_config.lr_scheduler,
             optimizer,
             max_iterations=self.train_config.steps,
-            lr_min=self.train_config.lr / 100,  # not sure why leco did this, but ill do it to
+            lr_min=self.train_config.lr / 100,
         )
         self.lr_scheduler = lr_scheduler
 
@@ -651,7 +651,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             ### HOOK ###
             loss_dict = self.hook_train_loop()
 
-            if self.train_config.optimizer.startswith('dadaptation'):
+            if self.train_config.optimizer.lower().startswith('dadaptation') or \
+                    self.train_config.optimizer.lower().startswith('prodigy'):
                 learning_rate = (
                     optimizer.param_groups[0]["d"] *
                     optimizer.param_groups[0]["lr"]
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index c0d15ae1..dd079c3e 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -1,5 +1,6 @@
 # ref:
 # - https://github.com/p1atdev/LECO/blob/main/train_lora.py
+import random
 import time
 from collections import OrderedDict
 import os
@@ -38,14 +39,17 @@ def flush():
 class EncodedPromptPair:
     def __init__(
             self,
+            target_class,
             positive_target,
             positive_target_with_neutral,
             negative_target,
             negative_target_with_neutral,
             neutral,
             both_targets,
-            empty_prompt
+            empty_prompt,
+            weight=1.0
     ):
+        self.target_class = target_class
         self.positive_target = positive_target
         self.positive_target_with_neutral = positive_target_with_neutral
         self.negative_target = negative_target
@@ -53,9 +57,11 @@ class EncodedPromptPair:
         self.neutral = neutral
         self.empty_prompt = empty_prompt
         self.both_targets = both_targets
+        self.weight = weight
 
     # simulate torch to for tensors
     def to(self, *args, **kwargs):
+        self.target_class = self.target_class.to(*args, **kwargs)
         self.positive_target = self.positive_target.to(*args, **kwargs)
         self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
         self.negative_target = self.negative_target.to(*args, **kwargs)
@@ -120,6 +126,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
 
         cache = PromptEmbedsCache()
 
+        if not self.slider_config.prompt_tensors:
+            # shuffle
+            random.shuffle(self.prompt_txt_list)
+            # trim to max steps
+            self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
+            # trim list to our max steps
+
+
         # get encoded latents for our prompts
         with torch.no_grad():
             if self.slider_config.prompt_tensors is not None:
@@ -129,7 +143,7 @@
                 self.print(f"Loading prompt tensors from {self.slider_config.prompt_tensors}")
                 prompt_tensors = load_file(self.slider_config.prompt_tensors, device='cpu')
                 # add them to the cache
-                for prompt_txt, prompt_tensor in prompt_tensors.items():
+                for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
                     if prompt_txt.startswith("te:"):
                         prompt = prompt_txt[3:]
                         # text_embeds
@@ -152,6 +166,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
             for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
                 for target in self.slider_config.targets:
                     prompt_list = [
+                        f"{target.target_class}",  # target_class
                        f"{target.positive}",  # positive_target
                        f"{target.positive} {neutral}",  # positive_target with neutral
                        f"{target.negative}",  # negative_target
@@ -195,6 +210,8 @@
                         neutral=cache[neutral],
                         both_targets=cache[both_prompts],
                         empty_prompt=cache[""],
+                        target_class=cache[f"{target.target_class}"],
+                        weight=target.weight,
                     ).to(device="cpu", dtype=torch.float32)
                     self.prompt_pairs.append(prompt_pair)
 
@@ -232,6 +249,19 @@
         lr_scheduler = self.lr_scheduler
         loss_function = torch.nn.MSELoss()
 
+        def get_noise_pred(p, n, gs, cts, dn):
+            return self.sd.pipeline.predict_noise(
+                latents=dn,
+                prompt_embeds=p.text_embeds,
+                negative_prompt_embeds=n.text_embeds,
+                pooled_prompt_embeds=p.pooled_embeds,
+                negative_pooled_prompt_embeds=n.pooled_embeds,
+                timestep=cts,
+                guidance_scale=gs,
+                num_images_per_prompt=self.train_config.batch_size,
+                num_inference_steps=1000,
+            )
+
         with torch.no_grad():
             self.sd.noise_scheduler.set_timesteps(
                 self.train_config.max_denoising_steps, device=self.device_torch
@@ -259,149 +289,139 @@
             torch.set_default_device(self.device_torch)
             self.sd.pipeline.set_progress_bar_config(disable=True)
 
-            # get generate semi denoised latents without network
-            # only neutrap in positive and both targets in negative
-            assert not self.network.is_active
-            # denoised_latents = self.sd.pipeline(
-            #     num_inference_steps=self.train_config.max_denoising_steps,
-            #     denoising_end=denoised_fraction,
-            #     latents=latents,
-            #     prompt_embeds=prompt_pair.neutral.text_embeds,
-            #     negative_prompt_embeds=prompt_pair.both_targets.text_embeds,
-            #     pooled_prompt_embeds=prompt_pair.neutral.pooled_embeds,
-            #     negative_pooled_prompt_embeds=prompt_pair.both_targets.pooled_embeds,
-            #     output_type="latent",
-            #     num_images_per_prompt=self.train_config.batch_size,
-            #     guidance_scale=3,
-            # ).images.to(self.device_torch, dtype=dtype)
+            with self.network:
+                assert self.network.is_active
+                self.network.multiplier = 1.0
+                POS_denoised_latents = self.sd.pipeline(
+                    num_inference_steps=self.train_config.max_denoising_steps,
+                    denoising_end=denoised_fraction,
+                    latents=latents,
+                    prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
+                    negative_prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
+                    pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
+                    negative_pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
+                    output_type="latent",
+                    num_images_per_prompt=self.train_config.batch_size,
+                    guidance_scale=3,
+                ).images.to(self.device_torch, dtype=dtype)
+
+                self.network.multiplier = -1.0
+                NEG_denoised_latents = self.sd.pipeline(
+                    num_inference_steps=self.train_config.max_denoising_steps,
+                    denoising_end=denoised_fraction,
+                    latents=latents,
+                    prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
+                    negative_prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
+                    pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
+                    negative_pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
+                    output_type="latent",
+                    num_images_per_prompt=self.train_config.batch_size,
+                    guidance_scale=3,
+                ).images.to(self.device_torch, dtype=dtype)
 
             noise_scheduler.set_timesteps(1000)
+
             current_timestep = noise_scheduler.timesteps[
                 int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
             ]
 
-            denoised_latents = noise
-            # neutral prediction
-            neutral_noise_prediction = self.sd.pipeline.predict_noise(
-                latents=denoised_latents,
-                prompt_embeds=prompt_pair.neutral.text_embeds,
-                negative_prompt_embeds=prompt_pair.empty_prompt.text_embeds,
-                pooled_prompt_embeds=prompt_pair.neutral.pooled_embeds,
-                negative_pooled_prompt_embeds=prompt_pair.both_targets.pooled_embeds,
-                timestep=current_timestep,
-                guidance_scale=1,
-                num_images_per_prompt=self.train_config.batch_size,
-                num_inference_steps=1000,
+            assert not self.network.is_active
+
+
+            # POSITIVE LATENTS
+            POS_positive_latents = get_noise_pred(
+                prompt_pair.negative_target_with_neutral,
+                prompt_pair.positive_target_with_neutral,
+                1, current_timestep, POS_denoised_latents,
+            )
+            NEG_positive_latents = get_noise_pred(
+                prompt_pair.positive_target_with_neutral,
+                prompt_pair.negative_target_with_neutral,
+                1, current_timestep, NEG_denoised_latents,
+            )
+
+
+            # NEUTRAL LATENTS
+            POS_neutral_latents = get_noise_pred(
+                prompt_pair.neutral,
+                prompt_pair.positive_target_with_neutral,
+                1, current_timestep, POS_denoised_latents,
+            )
+            NEG_neutral_latents = get_noise_pred(
+                prompt_pair.neutral,
+                prompt_pair.negative_target_with_neutral,
+                1, current_timestep, NEG_denoised_latents,
+            )
+
+
+            # UNCONDITIONAL LATENTS
+            POS_unconditional_latents = get_noise_pred(
+                prompt_pair.positive_target_with_neutral,
+                prompt_pair.positive_target_with_neutral,
+                1, current_timestep, POS_denoised_latents,
+            )
+            NEG_unconditional_latents = get_noise_pred(
+                prompt_pair.negative_target_with_neutral,
+                prompt_pair.negative_target_with_neutral,
+                1, current_timestep, NEG_denoised_latents,
             )
 
-            # with self.network:
-            #     assert self.network.is_active
-            #     self.network.multiplier = 1.0
-            #
-            #     positive_pos_noise_prediction = self.sd.pipeline.predict_noise(
-            #         latents=denoised_latents,
-            #         prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
-            #         negative_prompt_embeds=prompt_pair.negative_target.text_embeds,
-            #         pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
-            #         negative_pooled_prompt_embeds=prompt_pair.negative_target.pooled_embeds,
-            #         timestep=current_timestep,
-            #         guidance_scale=1,
-            #         num_images_per_prompt=self.train_config.batch_size,
-            #         num_inference_steps=1000
-            #     )
-            #
-            #     self.network.multiplier = -1.0
-            #
-            #     negative_neg_noise_prediction = self.sd.pipeline.predict_noise(
-            #         latents=denoised_latents,
-            #         prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
-            #         negative_prompt_embeds=prompt_pair.positive_target.text_embeds,
-            #         pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
-            #         negative_pooled_prompt_embeds=prompt_pair.positive_target.pooled_embeds,
-            #         timestep=current_timestep,
-            #         guidance_scale=1,
-            #         num_images_per_prompt=self.train_config.batch_size,
-            #         num_inference_steps=1000
-            #     )
 
         # start grads
         self.optimizer.zero_grad()
 
-        multiplier = 5.0
-
-        # predict postiitive
         with self.network:
             assert self.network.is_active
-            self.network.multiplier = multiplier * 1.0
-
-            # positive_pos_noise_prediction = self.sd.pipeline.predict_noise(
-            #     latents=denoised_latents,
-            #     prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
-            #     negative_prompt_embeds=prompt_pair.negative_target.text_embeds,
-            #     pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
-            #     negative_pooled_prompt_embeds=prompt_pair.negative_target.pooled_embeds,
-            #     timestep=current_timestep,
-            #     guidance_scale=1,
-            #     num_images_per_prompt=self.train_config.batch_size,
-            #     num_inference_steps=self.train_config.max_denoising_steps,
-            # )
-
-            negative_pos_noise_prediction = self.sd.pipeline.predict_noise(
-                latents=denoised_latents,
-                prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
-                negative_prompt_embeds=prompt_pair.positive_target.text_embeds,
-                pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
-                negative_pooled_prompt_embeds=prompt_pair.positive_target.pooled_embeds,
-                timestep=current_timestep,
-                guidance_scale=1,
-                num_images_per_prompt=self.train_config.batch_size,
-                num_inference_steps=1000,
-            )
-
-            self.network.multiplier = multiplier * -1.0
-
-            positive_neg_noise_prediction = self.sd.pipeline.predict_noise(
-                latents=denoised_latents,
-                prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
-                negative_prompt_embeds=prompt_pair.negative_target.text_embeds,
-                pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
-                negative_pooled_prompt_embeds=prompt_pair.negative_target.pooled_embeds,
-                timestep=current_timestep,
-                guidance_scale=1,
-                num_images_per_prompt=self.train_config.batch_size,
-                num_inference_steps=1000,
-            )
-
-            # negative_neg_noise_prediction = self.sd.pipeline.predict_noise(
-            #     latents=denoised_latents,
-            #     prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
-            #     negative_prompt_embeds=prompt_pair.positive_target.text_embeds,
-            #     pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
-            #     negative_pooled_prompt_embeds=prompt_pair.positive_target.pooled_embeds,
-            #     timestep=current_timestep,
-            #     guidance_scale=1,
-            #     num_images_per_prompt=self.train_config.batch_size,
-            #     num_inference_steps=self.train_config.max_denoising_steps,
-            # )
-
             self.network.multiplier = 1.0
+            POS_target_latents = get_noise_pred(
+                prompt_pair.negative_target_with_neutral,
+                prompt_pair.positive_target_with_neutral,
+                1, current_timestep, POS_denoised_latents,
+            )
 
-        neutral_noise_prediction.requires_grad = False
-        # positive_pos_noise_prediction.requires_grad = False
-        # negative_neg_noise_prediction.requires_grad = False
+            self.network.multiplier = -1.0
+            NEG_target_latents = get_noise_pred(
+                prompt_pair.positive_target_with_neutral,
+                prompt_pair.negative_target_with_neutral,
+                1, current_timestep, NEG_denoised_latents,
+            )
 
-        # calculate loss
-        loss_shrink_pos_neg = loss_function(
-            negative_pos_noise_prediction,
-            neutral_noise_prediction,
-        )
+        POS_positive_latents.requires_grad = False
+        NEG_positive_latents.requires_grad = False
+        POS_neutral_latents.requires_grad = False
+        NEG_neutral_latents.requires_grad = False
+        POS_unconditional_latents.requires_grad = False
+        NEG_unconditional_latents.requires_grad = False
 
-        loss_shrink_neg_pos = loss_function(
-            positive_neg_noise_prediction,
-            negative_pos_noise_prediction,
-        )
+        guidance_scale = 1.0
 
-        loss = loss_shrink_pos_neg + loss_shrink_neg_pos
+        POS_offset = guidance_scale * (POS_positive_latents - POS_unconditional_latents)
+        NEG_offset = guidance_scale * (NEG_positive_latents - NEG_unconditional_latents)
+
+        erase = True
+
+        POS_offset_neutral = POS_neutral_latents
+        NEG_offset_neutral = NEG_neutral_latents
+        # if erase:
+        #     POS_offset_neutral -= POS_offset
+        #     NEG_offset_neutral -= NEG_offset
+        # else:
+        #     # enhance
+        #     POS_offset_neutral += POS_offset
+        #     NEG_offset_neutral += NEG_offset
+
+        POS_erase_loss = loss_function(
+            POS_target_latents,
+            POS_neutral_latents - POS_offset,
+        ) * prompt_pair.weight
+
+        NEG_erase_loss = loss_function(
+            NEG_target_latents,
+            NEG_neutral_latents - NEG_offset,
+        ) * prompt_pair.weight
+
+
+        loss = (POS_erase_loss + NEG_erase_loss) * 0.5
 
         loss_float = loss.item()
@@ -412,11 +432,28 @@
         lr_scheduler.step()
 
         del (
-            denoised_latents,
-            positive_neg_noise_prediction,
-            negative_pos_noise_prediction,
-            neutral_noise_prediction,
-            latents,
+            # denoised_latents,
+            POS_denoised_latents,
+            NEG_denoised_latents,
+            # positive_neg_noise_prediction,
+            POS_positive_latents,
+            NEG_positive_latents,
+            # neutral_noise_prediction,
+            POS_neutral_latents,
+            NEG_neutral_latents,
+            # unconditional_noise_prediction,
+            POS_unconditional_latents,
+            NEG_unconditional_latents,
+            # target_noise_prediction,
+            POS_target_latents,
+            NEG_target_latents,
+            # offset,
+            POS_offset,
+            NEG_offset,
+            # offset_neutral,
+            POS_offset_neutral,
+            NEG_offset_neutral,
+
         )
         # move back to cpu
         prompt_pair.to("cpu")
@@ -426,7 +463,11 @@
         self.network.multiplier = 1.0
 
         loss_dict = OrderedDict(
-            {'loss': loss_float},
+            {
+                'loss': loss.item(),
+                'l+er': POS_erase_loss.item(),
+                'l-er': NEG_erase_loss.item(),
+            },
         )
 
         return loss_dict
diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index a9a9b6e7..09ebb734 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -27,6 +27,17 @@ def get_optimizer(
         optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
         # warn user that dadaptation is deprecated
         print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
+    elif lower_type.startswith("prodigy"):
+        from prodigyopt import Prodigy
+
+        print("Using Prodigy optimizer")
+        use_lr = learning_rate
+        if use_lr < 0.1:
+            # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
+            use_lr = 1.0
+        # let net be the neural network you want to train
+        # you can choose weight decay value based on your problem, 0 by default
+        optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
     elif lower_type.endswith("8bit"):
         import bitsandbytes
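
Illustrative note (not part of the patch): the loss added to TrainSliderProcess.py above follows a concept-erasure pattern. For each network polarity, the trained prediction is pulled toward the neutral prediction minus a guidance-scaled offset between the concept-conditioned and unconditional predictions. The sketch below is a minimal, standalone restatement of that computation with hypothetical tensor names; it is not code from the repository, and it assumes the reference predictions were produced without gradients.

    import torch
    import torch.nn.functional as F

    def erase_loss(target_pred, neutral_pred, concept_pred, uncond_pred,
                   guidance_scale=1.0, weight=1.0):
        # offset points from the unconditional prediction toward the concept being erased
        offset = guidance_scale * (concept_pred - uncond_pred)
        # pull the trained prediction toward "neutral minus concept"
        return F.mse_loss(target_pred, neutral_pred - offset) * weight

    # mirroring the patch: average the +1.0 and -1.0 multiplier passes
    # loss = 0.5 * (erase_loss(POS_target, POS_neutral, POS_positive, POS_uncond)
    #               + erase_loss(NEG_target, NEG_neutral, NEG_positive, NEG_uncond))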