diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 7768904f..8a7e891c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -3,6 +3,8 @@ import time from collections import OrderedDict import os +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg + from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer from toolkit.paths import REPOS_ROOT @@ -383,6 +385,7 @@ class BaseSDTrainProcess(BaseTrainProcess): add_time_ids=None, **kwargs, ): + if self.sd.is_xl: if add_time_ids is None: add_time_ids = self.get_time_ids_from_latents(latents) @@ -407,25 +410,31 @@ class BaseSDTrainProcess(BaseTrainProcess): # perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - guided_target = noise_pred_uncond + guidance_scale * ( + noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 - # noise_pred = rescale_noise_cfg( - # noise_pred, noise_pred_text, guidance_rescale=guidance_rescale - # ) - - noise_pred = guided_target + if guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) else: - noise_pred = train_util.predict_noise( - self.sd.unet, - self.sd.noise_scheduler, + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = self.sd.unet( + latent_model_input, timestep, - latents, - text_embeddings.text_embeds if hasattr(text_embeddings, 'text_embeds') else text_embeddings, - guidance_scale=guidance_scale + encoder_hidden_states=text_embeddings.text_embeds, + ).sample + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond ) return noise_pred diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 1a9c8614..0913eebf 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -40,6 +40,7 @@ class EncodedPromptPair: def __init__( self, target_class, + target_class_with_neutral, positive_target, positive_target_with_neutral, negative_target, @@ -52,6 +53,7 @@ class EncodedPromptPair: weight=1.0 ): self.target_class = target_class + self.target_class_with_neutral = target_class_with_neutral self.positive_target = positive_target self.positive_target_with_neutral = positive_target_with_neutral self.negative_target = negative_target @@ -171,6 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess): for target in self.slider_config.targets: prompt_list = [ f"{target.target_class}", # target_class + f"{target.target_class} {neutral}", # target_class with neutral f"{target.positive}", # positive_target f"{target.positive} {neutral}", # positive_target with neutral f"{target.negative}", # negative_target @@ -217,6 +220,7 @@ class 
TrainSliderProcess(BaseSDTrainProcess): # erase standard EncodedPromptPair( target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], positive_target=cache[f"{target.positive}"], positive_target_with_neutral=cache[f"{target.positive} {neutral}"], negative_target=cache[f"{target.negative}"], @@ -234,6 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess): # enhance standard, swap pos neg EncodedPromptPair( target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], positive_target=cache[f"{target.negative}"], positive_target_with_neutral=cache[f"{target.negative} {neutral}"], negative_target=cache[f"{target.positive}"], @@ -251,6 +256,7 @@ class TrainSliderProcess(BaseSDTrainProcess): # erase inverted EncodedPromptPair( target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], positive_target=cache[f"{target.negative}"], positive_target_with_neutral=cache[f"{target.negative} {neutral}"], negative_target=cache[f"{target.positive}"], @@ -268,6 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess): # enhance inverted EncodedPromptPair( target_class=cache[target.target_class], + target_class_with_neutral=cache[f"{target.target_class} {neutral}"], positive_target=cache[f"{target.positive}"], positive_target_with_neutral=cache[f"{target.positive} {neutral}"], negative_target=cache[f"{target.negative}"], @@ -299,28 +306,7 @@ class TrainSliderProcess(BaseSDTrainProcess): multiplier=anchor.multiplier ) ] - # self.print("Encoding complete. 
Building prompt pairs..") - # for neutral in self.prompt_txt_list: - # for target in self.slider_config.targets: - # both_prompts_list = [ - # f"{target.positive} {target.negative}", - # f"{target.negative} {target.positive}", - # ] - # # randomly pick one of the both prompts to prevent bias - # both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()] - # - # prompt_pair = EncodedPromptPair( - # positive_target=cache[f"{target.positive}"], - # positive_target_with_neutral=cache[f"{target.positive} {neutral}"], - # negative_target=cache[f"{target.negative}"], - # negative_target_with_neutral=cache[f"{target.negative} {neutral}"], - # neutral=cache[neutral], - # both_targets=cache[both_prompts], - # empty_prompt=cache[""], - # target_class=cache[f"{target.target_class}"], - # weight=target.weight, - # ).to(device="cpu", dtype=torch.float32) - # self.prompt_pairs.append(prompt_pair) + # move to cpu to save vram # We don't need text encoder anymore, but keep it on cpu for sampling @@ -340,7 +326,8 @@ class TrainSliderProcess(BaseSDTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) # get random multiplier between 1 and 3 - rand_weight = torch.rand((1,)).item() * 2 + 1 + rand_weight = 1 + # rand_weight = torch.rand((1,)).item() * 2 + 1 # get a random pair prompt_pair: EncodedPromptPair = self.prompt_pairs[ @@ -367,12 +354,12 @@ class TrainSliderProcess(BaseSDTrainProcess): lr_scheduler = self.lr_scheduler loss_function = torch.nn.MSELoss() - def get_noise_pred(p, n, gs, cts, dn): + def get_noise_pred(neg, pos, gs, cts, dn): return self.predict_noise( latents=dn, text_embeddings=train_tools.concat_prompt_embeddings( - p, # negative prompt - n, # positive prompt + neg, # negative prompt + pos, # positive prompt self.train_config.batch_size, ), timestep=cts, @@ -410,8 +397,8 @@ class TrainSliderProcess(BaseSDTrainProcess): denoised_latents = self.diffuse_some_steps( latents, # pass simple noise latents train_tools.concat_prompt_embeddings( - positive, 
# unconditional - target_class, # target + prompt_pair.positive_target, # unconditional + prompt_pair.target_class, # target self.train_config.batch_size, ), start_timesteps=0, @@ -426,15 +413,27 @@ class TrainSliderProcess(BaseSDTrainProcess): ] positive_latents = get_noise_pred( - positive, negative, 1, current_timestep, denoised_latents + prompt_pair.positive_target, # negative prompt + prompt_pair.negative_target, # positive prompt + 1, + current_timestep, + denoised_latents ).to("cpu", dtype=torch.float32) neutral_latents = get_noise_pred( - positive, neutral, 1, current_timestep, denoised_latents + prompt_pair.positive_target, # negative prompt + prompt_pair.empty_prompt, # positive prompt (normally neutral) + 1, + current_timestep, + denoised_latents ).to("cpu", dtype=torch.float32) unconditional_latents = get_noise_pred( - positive, positive, 1, current_timestep, denoised_latents + prompt_pair.positive_target, # negative prompt + prompt_pair.positive_target, # positive prompt + 1, + current_timestep, + denoised_latents ).to("cpu", dtype=torch.float32) anchor_loss = None @@ -461,7 +460,11 @@ class TrainSliderProcess(BaseSDTrainProcess): with self.network: self.network.multiplier = prompt_pair.multiplier * rand_weight target_latents = get_noise_pred( - positive, target_class, 1, current_timestep, denoised_latents + prompt_pair.positive_target, + prompt_pair.target_class, + 1, + current_timestep, + denoised_latents ).to("cpu", dtype=torch.float32) # if self.logging_config.verbose: