From 7782caa468573462b4bb3010b8221d83473a37dc Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 10 Nov 2023 12:18:08 -0700
Subject: [PATCH] Various bug fixes. Finalized targeted guidance algo

---
 extensions_built_in/sd_trainer/SDTrainer.py | 114 +++++++-------------
 jobs/process/BaseSDTrainProcess.py          |  24 +----
 jobs/process/TrainSliderProcess.py          |  33 +++---
 toolkit/dataloader_mixins.py                |   4 +-
 toolkit/embedding.py                        |  55 +++++-----
 5 files changed, 92 insertions(+), 138 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 18025c29..0dcba154 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -187,117 +187,81 @@ class SDTrainer(BaseSDTrainProcess):
             **kwargs
     ):
         with torch.no_grad():
-            conditional_noisy_latents = noisy_latents
+            # Perform targeted guidance (working title)
+            conditional_noisy_latents = noisy_latents  # target images
             dtype = get_torch_dtype(self.train_config.dtype)

             if batch.unconditional_latents is not None:
-                # Encode the unconditional image into latents
+                # unconditional latents are the "neutral" images. Add noise here identical to
+                # the noise added to the conditional latents, at the same timesteps
                 unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
                     batch.unconditional_latents,
                     noise,
                     timesteps
                 )
-                # was_network_active = self.network.is_active
+                # calculate the differential between our conditional (target image) and our unconditional (neutral image)
+                target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
+                target_differential_noise = target_differential_noise.detach()
+
+                # add the target differential to the target latents as if it were noise with the scheduler, scaled to
+                # the current timestep. Scaling the noise here is important as it scales our guidance to the current
+                # timestep. This is the key to making the guidance work.
+                guidance_latents = self.sd.noise_scheduler.add_noise(
+                    conditional_noisy_latents,
+                    target_differential_noise,
+                    timesteps
+                )
+
+                # Disable the LoRA network so we can predict parent network knowledge without it
                 self.network.is_active = False
                 self.sd.unet.eval()

-                # calculate the differential between our conditional (target image) and out unconditional ("bad" image)
-                target_differential = unconditional_noisy_latents - conditional_noisy_latents
-
-                # scale the target differential by the scheduler
-                # todo, scale it the right way
-                # target_differential = self.sd.noise_scheduler.add_noise(
-                #     torch.zeros_like(target_differential),
-                #     target_differential,
-                #     timesteps
-                # )
-
-                target_differential = target_differential.detach()
-
-                # add the target differential to the target latents as if it were noise with the scheduler scaled to
-                # the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
-                # properly
-                # guidance_latents = self.sd.noise_scheduler.add_noise(
-                #     conditional_noisy_latents,
-                #     target_differential,
-                #     timesteps
-                # )
-
-                # guidance_latents = conditional_noisy_latents + target_differential
-                # target_noise = conditional_noisy_latents + target_differential
-
-                # With LoRA network bypassed, predict noise to get a baseline of what the network
-                # wants to do with the latents + noise. Pass our target latents here for the input.
-                target_unconditional = self.sd.predict_noise(
-                    latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+                # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
+                # This acts as our control to preserve the unaltered parts of the image.
+                baseline_prediction = self.sd.predict_noise(
+                    latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
                     conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
                     timestep=timesteps,
                     guidance_scale=1.0,
                     **pred_kwargs  # adapter residuals in here
                 ).detach()
-                target_conditional = self.sd.predict_noise(
-                    latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
-                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
-                    timestep=timesteps,
-                    guidance_scale=1.0,
-                    **pred_kwargs  # adapter residuals in here
-                ).detach()
-
-                # we calculate the networks current knowledge so we do not overlearn what we know
-                current_knowledge = target_unconditional - target_conditional
-
-                # we now have the differential noise prediction needed to create our convergence target
-                target_unknown_knowledge = target_differential - current_knowledge

                 # turn the LoRA network back on.
                 self.sd.unet.train()
                 self.network.is_active = True

                 self.network.multiplier = network_weight_list

-                # with LoRA active, predict the noise with the scaled differential latents added. This will allow us
-                # the opportunity to predict the differential + noise that was added to the latents.
-                prediction_unconditional = self.sd.predict_noise(
-                    latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+                # do our prediction with LoRA active on the scaled guidance latents
+                prediction = self.sd.predict_noise(
+                    latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
                     conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
                     timestep=timesteps,
                     guidance_scale=1.0,
                     **pred_kwargs  # adapter residuals in here
                 )

-                # remove the baseline conditional prediction. This will leave only the divergence from the baseline and
-                # the prediction of the added differential noise
-                # prediction_positive = prediction_unconditional - target_unconditional
-                prediction_positive = target_unconditional - prediction_unconditional
+                # remove the baseline prediction from our prediction to get the differential between the two
+                # all that should be left is the differential between the conditional and unconditional images
+                pred_differential_noise = prediction - baseline_prediction

                 # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
-                # this is the diffusion training process.
+                # not the timestep scaled noise that was added. This is the diffusion training process.
                 # This will guide the network to make identical predictions it previously did for everything EXCEPT our
-                # differential between the conditional and unconditional images
-
-                positive_loss = torch.nn.functional.mse_loss(
-                    prediction_positive.float(),
-                    target_unknown_knowledge.float(),
+                # differential between the conditional and unconditional images (target)
+                loss = torch.nn.functional.mse_loss(
+                    pred_differential_noise.float(),
+                    target_differential_noise.float(),
                     reduction="none"
                 )
-                # add adain loss
-                positive_loss = positive_loss
+                loss = loss.mean([1, 2, 3])
+                loss = self.apply_snr(loss, timesteps)
+                loss = loss.mean()
+                loss.backward()

-                positive_loss = positive_loss.mean([1, 2, 3])
-
-                # positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
-                # send it backwards BEFORE switching network polarity
-                positive_loss = self.apply_snr(positive_loss, timesteps)
-                positive_loss = positive_loss.mean()
-                positive_loss.backward()
-                # loss = positive_loss.detach() + negative_loss.detach()
-                loss = positive_loss.detach()
-
-                # add a grad so other backward does not fail
+                # detach it so the parent class can run backward on no grads without throwing an error
+                loss = loss.detach()
                 loss.requires_grad_(True)

-                # restore network
-                self.network.multiplier = network_weight_list
-
         return loss

     def get_prior_prediction(
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 01e57b1b..1c6453be 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1061,6 +1061,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # load last saved weights
             if latest_save_path is not None:
                 self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
+
+            # self.step_num = self.embedding.step
+            # self.start_step = self.step_num
             params.append({
                 'params': self.embedding.get_trainable_params(),
                 'lr': self.train_config.embedding_lr
@@ -1068,27 +1071,6 @@ class BaseSDTrainProcess(BaseTrainProcess):

             flush()

-        if self.embed_config is not None:
-            self.embedding = Embedding(
-                sd=self.sd,
-                embed_config=self.embed_config
-            )
-            latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
-            # load last saved weights
-            if latest_save_path is not None:
-                self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
-
-            # resume state from embedding
-            self.step_num = self.embedding.step
-            self.start_step = self.step_num
-
-            params = self.get_params()
-            if not params:
-                # set trainable params
-                params = self.embedding.get_trainable_params()
-
-            flush()
-
         if self.adapter_config is not None:
             self.setup_adapter()
             # set trainable params
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index f412d1ee..9f18e6bc 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -327,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):

         with torch.no_grad():
             adapter_images = None
+            self.sd.unet.eval()
             # for a complete slider, the batch size is 4 to begin with now
             true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size

@@ -385,21 +386,22 @@ class TrainSliderProcess(BaseSDTrainProcess):
             latents = noise * self.sd.noise_scheduler.init_noise_sigma
             latents = latents.to(self.device_torch, dtype=dtype)

-            with self.network:
-                assert self.network.is_active
-                # pass the multiplier list to the network
-                self.network.multiplier = prompt_pair.multiplier_list
-                denoised_latents = self.sd.diffuse_some_steps(
-                    latents,  # pass simple noise latents
-                    train_tools.concat_prompt_embeddings(
-                        prompt_pair.positive_target,  # unconditional
-                        prompt_pair.target_class,  # target
-                        self.train_config.batch_size,
-                    ),
-                    start_timesteps=0,
-                    total_timesteps=timesteps_to,
-                    guidance_scale=3,
-                )
+            assert not self.network.is_active
+            self.sd.unet.eval()
+            # pass the multiplier list to the network
+            self.network.multiplier = prompt_pair.multiplier_list
+            denoised_latents = self.sd.diffuse_some_steps(
+                latents,  # pass simple noise latents
+                train_tools.concat_prompt_embeddings(
+                    prompt_pair.positive_target,  # unconditional
+                    prompt_pair.target_class,  # target
+                    self.train_config.batch_size,
+                ),
+                start_timesteps=0,
+                total_timesteps=timesteps_to,
+                guidance_scale=3,
+            )
+

             noise_scheduler.set_timesteps(1000)

@@ -473,6 +475,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
         denoised_latents = denoised_latents.detach()

         self.sd.set_device_state(self.train_slider_device_state)
+        self.sd.unet.train()

         # start accumulating gradients
         self.optimizer.zero_grad(set_to_none=True)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index d82d4144..fc26d77d 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -775,9 +775,9 @@ class PoiFileItemDTOMixin:
            with open(caption_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
                if 'poi' not in json_data:
-                    raise Exception(f"Error: poi not found in caption file: {caption_path}")
+                    print(f"Warning: poi not found in caption file: {caption_path}")
                if self.poi not in json_data['poi']:
-                    raise Exception(f"Error: poi not found in caption file: {caption_path}")
+                    print(f"Warning: poi not found in caption file: {caption_path}")
        # poi has, x, y, width, height
        # do full image if no poi
        self.poi_x = 0
diff --git a/toolkit/embedding.py b/toolkit/embedding.py
index 142b3913..31ac4ce2 100644
--- a/toolkit/embedding.py
+++ b/toolkit/embedding.py
@@ -47,12 +47,15 @@ class Embedding:
         self.placeholder_token_ids = []
         self.embedding_tokens = []

+        print(f"Adding {placeholder_tokens} tokens to tokenizer")
+        print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
+
         for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
             num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
             if num_added_tokens != self.embed_config.tokens:
                 raise ValueError(
                     f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
-                    " `placeholder_token` that is not already in the tokenizer."
+                    f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
                 )

             # Convert the initializer_token, placeholder_token to ids
@@ -115,10 +118,10 @@ class Embedding:

     def _set_vec(self, new_vector, text_encoder_idx=0):
         # shape is (1, 768) for SD 1.5 for 1 token
-        token_embeds = self.text_encoder_list[0].get_input_embeddings().weight.data
+        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
         for i in range(new_vector.shape[0]):
             # apply the weights to the placeholder tokens while preserving gradient
-            token_embeds[self.placeholder_token_ids[0][i]] = new_vector[i].clone()
+            token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()

     # make setter and getter for vec
     @property
@@ -249,30 +252,32 @@ class Embedding:
             else:
                 return

-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            if hasattr(param_dict, '_parameters'):
-                param_dict = getattr(param_dict,
-                                     '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-        # diffuser concepts
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-        else:
-            raise Exception(
-                f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        if 'step' in data:
-            self.step = int(data['step'])
-
         if self.sd.is_xl:
             self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
             self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
+            if 'step' in data:
+                self.step = int(data['step'])
         else:
+            # textual inversion embeddings
+            if 'string_to_param' in data:
+                param_dict = data['string_to_param']
+                if hasattr(param_dict, '_parameters'):
+                    param_dict = getattr(param_dict,
+                                         '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+                emb = next(iter(param_dict.items()))[1]
+            # diffuser concepts
+            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+                emb = next(iter(data.values()))
+                if len(emb.shape) == 1:
+                    emb = emb.unsqueeze(0)
+            else:
+                raise Exception(
+                    f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")
+
+            if 'step' in data:
+                self.step = int(data['step'])
+
             self.vec = emb.detach().to(device, dtype=torch.float32)
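
Illustrative note (not part of the patch): the sketch below is a minimal, self-contained summary of the targeted guidance loss that the SDTrainer.py hunk implements. The toolkit's scheduler, UNet, and LoRA wrappers are replaced with plain callables; the helper names (add_noise, predict_baseline, predict_with_lora) are placeholders for illustration only and do not exist in the toolkit, and the SNR weighting and backward call handled by the trainer are omitted.

import torch
import torch.nn.functional as F


def targeted_guidance_loss(
        conditional_noisy_latents: torch.Tensor,    # noised target images
        unconditional_noisy_latents: torch.Tensor,  # neutral images, noised with the same noise and timesteps
        timesteps: torch.Tensor,
        add_noise,          # callable(latents, noise, timesteps) -> latents, e.g. scheduler.add_noise
        predict_baseline,   # callable(latents, timesteps) -> noise prediction with the LoRA disabled
        predict_with_lora,  # callable(latents, timesteps) -> noise prediction with the LoRA enabled
) -> torch.Tensor:
    # differential between the target (conditional) and neutral (unconditional) latents
    target_differential_noise = (unconditional_noisy_latents - conditional_noisy_latents).detach()

    # treat the differential as noise and let the scheduler scale it to the current timestep
    guidance_latents = add_noise(conditional_noisy_latents, target_differential_noise, timesteps)

    # baseline prediction with the LoRA bypassed acts as the control that preserves unaltered regions
    baseline_prediction = predict_baseline(guidance_latents, timesteps).detach()

    # prediction with the LoRA active on the same scaled guidance latents
    prediction = predict_with_lora(guidance_latents, timesteps)

    # the LoRA only has to account for the unscaled differential itself
    pred_differential_noise = prediction - baseline_prediction
    loss = F.mse_loss(pred_differential_noise.float(), target_differential_noise.float(), reduction="none")
    return loss.mean([1, 2, 3]).mean()


# toy usage with stand-in tensors and callables
if __name__ == "__main__":
    cond = torch.randn(2, 4, 64, 64)
    uncond = torch.randn(2, 4, 64, 64)
    t = torch.randint(0, 1000, (2,))
    loss = targeted_guidance_loss(
        cond, uncond, t,
        add_noise=lambda latents, noise, timesteps: latents + noise,  # stand-in for scheduler.add_noise
        predict_baseline=lambda latents, timesteps: torch.zeros_like(latents),
        predict_with_lora=lambda latents, timesteps: torch.zeros_like(latents),
    )
    print(loss.item())

Scaling the injected differential with the scheduler's add_noise, rather than adding it raw, is what keeps the guidance consistent with the timestep being trained, per the comments in the hunk.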