diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index d82ff306..7768904f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -103,7 +103,27 @@ class BaseSDTrainProcess(BaseTrainProcess): # self.sd.text_encoder.to(self.device_torch) # self.sd.tokenizer.to(self.device_torch) # TODO add clip skip - pipeline = self.sd.pipeline + if self.sd.is_xl: + pipeline = StableDiffusionXLPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder[0], + text_encoder_2=self.sd.text_encoder[1], + tokenizer=self.sd.tokenizer[0], + tokenizer_2=self.sd.tokenizer[1], + scheduler=self.sd.noise_scheduler, + ) + else: + pipeline = StableDiffusionPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder, + tokenizer=self.sd.tokenizer, + scheduler=self.sd.noise_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) # disable progress bar pipeline.set_progress_bar_config(disable=True) @@ -162,16 +182,24 @@ class BaseSDTrainProcess(BaseTrainProcess): torch.manual_seed(current_seed) torch.cuda.manual_seed(current_seed) - img = pipeline( - prompt=prompt, - prompt_2=prompt, - negative_prompt=neg, - negative_prompt_2=neg, - height=height, - width=width, - num_inference_steps=sample_config.sample_steps, - guidance_scale=sample_config.guidance_scale, - ).images[0] + if self.sd.is_xl: + img = pipeline( + prompt, + height=height, + width=width, + num_inference_steps=sample_config.sample_steps, + guidance_scale=sample_config.guidance_scale, + negative_prompt=neg, + ).images[0] + else: + img = pipeline( + prompt, + height=height, + width=width, + num_inference_steps=sample_config.sample_steps, + guidance_scale=sample_config.guidance_scale, + negative_prompt=neg, + ).images[0] step_num = '' if step is not None: @@ -184,6 +212,8 @@ class BaseSDTrainProcess(BaseTrainProcess): output_path = os.path.join(sample_folder, filename) img.save(output_path) + # clear pipeline and cache to reduce vram usage + del pipeline torch.cuda.empty_cache() # restore training state @@ -259,12 +289,15 @@ class BaseSDTrainProcess(BaseTrainProcess): # prepare meta save_meta = get_meta_for_safetensors(self.meta, self.job.name) if self.network is not None: + prev_multiplier = self.network.multiplier + self.network.multiplier = 1.0 # TODO handle dreambooth, fine tuning, etc self.network.save_weights( file_path, dtype=get_torch_dtype(self.save_config.dtype), metadata=save_meta ) + self.network.multiplier = prev_multiplier else: self.sd.save( file_path, @@ -340,19 +373,6 @@ class BaseSDTrainProcess(BaseTrainProcess): else: return None - def predict_noise_xl( - self, - latents: torch.FloatTensor, - positive_prompt: str, - negative_prompt: str, - timestep: int, - guidance_scale=7.5, - guidance_rescale=0.7, - add_time_ids=None, - **kwargs, - ): - pass - def predict_noise( self, latents: torch.FloatTensor, diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index dd079c3e..90535ed8 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -47,6 +47,8 @@ class EncodedPromptPair: neutral, both_targets, empty_prompt, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=1.0, weight=1.0 ): self.target_class = target_class @@ -57,6 +59,8 @@ class EncodedPromptPair: self.neutral = neutral self.empty_prompt = empty_prompt self.both_targets = both_targets + self.multiplier = multiplier + self.action: int = action self.weight = weight # simulate torch to for tensors @@ -180,6 +184,18 @@ class TrainSliderProcess(BaseSDTrainProcess): if cache[p] is None: cache[p] = self.sd.encode_prompt(p).to(device="cpu", dtype=torch.float32) + erase_negative = len(target.positive.strip()) == 0 + enhance_positive = len(target.negative.strip()) == 0 + + both = not erase_negative and not enhance_positive + + if erase_negative and enhance_positive: + raise ValueError("target must have at least one of positive or negative or both") + # for slider we need to have an enhancer, an eraser, and then + # an inverse with negative weights to balance the network + # if we don't do this, we will get different contrast and focus. + # we only perform actions of enhancing and erasing on the negative + # todo work on way to do all of this in one shot if self.slider_config.prompt_tensors: print(f"Saving prompt tensors to {self.slider_config.prompt_tensors}") state_dict = {} @@ -192,28 +208,115 @@ class TrainSliderProcess(BaseSDTrainProcess): 'fp16')) save_file(state_dict, self.slider_config.prompt_tensors) - self.print("Encoding complete. Building prompt pairs..") - for neutral in self.prompt_txt_list: + prompt_pairs = [] + for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False): for target in self.slider_config.targets: - both_prompts_list = [ - f"{target.positive} {target.negative}", - f"{target.negative} {target.positive}", - ] - # randomly pick one of the both prompts to prevent bias - both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()] - prompt_pair = EncodedPromptPair( - positive_target=cache[f"{target.positive}"], - positive_target_with_neutral=cache[f"{target.positive} {neutral}"], - negative_target=cache[f"{target.negative}"], - negative_target_with_neutral=cache[f"{target.negative} {neutral}"], - neutral=cache[neutral], - both_targets=cache[both_prompts], - empty_prompt=cache[""], - target_class=cache[f"{target.target_class}"], - weight=target.weight, - ).to(device="cpu", dtype=torch.float32) - self.prompt_pairs.append(prompt_pair) + if both or erase_negative: + prompt_pairs += [ + # erase standard + EncodedPromptPair( + target_class=cache[target.target_class], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier, + empty_prompt=cache[""], + weight=target.weight + ), + ] + if both or enhance_positive: + prompt_pairs += [ + # enhance standard, swap pos neg + EncodedPromptPair( + target_class=cache[target.target_class], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier, + empty_prompt=cache[""], + weight=target.weight + ), + ] + if both or enhance_positive: + prompt_pairs += [ + # erase inverted + EncodedPromptPair( + target_class=cache[target.target_class], + positive_target=cache[f"{target.negative}"], + positive_target_with_neutral=cache[f"{target.negative} {neutral}"], + negative_target=cache[f"{target.positive}"], + negative_target_with_neutral=cache[f"{target.positive} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + if both or erase_negative: + prompt_pairs += [ + # enhance inverted + EncodedPromptPair( + target_class=cache[target.target_class], + positive_target=cache[f"{target.positive}"], + positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + negative_target=cache[f"{target.negative}"], + negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + neutral=cache[neutral], + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + empty_prompt=cache[""], + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + + # setup anchors + anchor_pairs = [] + for anchor in self.slider_config.anchors: + # build the cache + for prompt in [ + anchor.prompt, + anchor.neg_prompt # empty neutral + ]: + if cache[prompt] == None: + cache[prompt] = self.sd.encode_prompt(prompt) + + anchor_pairs += [ + EncodedAnchor( + prompt=cache[anchor.prompt], + neg_prompt=cache[anchor.neg_prompt], + multiplier=anchor.multiplier + ) + ] + # self.print("Encoding complete. Building prompt pairs..") + # for neutral in self.prompt_txt_list: + # for target in self.slider_config.targets: + # both_prompts_list = [ + # f"{target.positive} {target.negative}", + # f"{target.negative} {target.positive}", + # ] + # # randomly pick one of the both prompts to prevent bias + # both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()] + # + # prompt_pair = EncodedPromptPair( + # positive_target=cache[f"{target.positive}"], + # positive_target_with_neutral=cache[f"{target.positive} {neutral}"], + # negative_target=cache[f"{target.negative}"], + # negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + # neutral=cache[neutral], + # both_targets=cache[both_prompts], + # empty_prompt=cache[""], + # target_class=cache[f"{target.target_class}"], + # weight=target.weight, + # ).to(device="cpu", dtype=torch.float32) + # self.prompt_pairs.append(prompt_pair) # move to cpu to save vram # We don't need text encoder anymore, but keep it on cpu for sampling @@ -224,7 +327,8 @@ class TrainSliderProcess(BaseSDTrainProcess): else: self.sd.text_encoder.to("cpu") self.prompt_cache = cache - + self.prompt_pairs = prompt_pairs + self.anchor_pairs = anchor_pairs flush() # end hook_before_train_loop @@ -243,6 +347,13 @@ class TrainSliderProcess(BaseSDTrainProcess): torch.randint(0, len(self.slider_config.resolutions), (1,)).item() ] + target_class = prompt_pair.target_class + neutral = prompt_pair.neutral + negative = prompt_pair.negative_target + positive = prompt_pair.positive_target + weight = prompt_pair.weight + multiplier = prompt_pair.multiplier + unet = self.sd.unet noise_scheduler = self.sd.noise_scheduler optimizer = self.optimizer @@ -250,18 +361,20 @@ class TrainSliderProcess(BaseSDTrainProcess): loss_function = torch.nn.MSELoss() def get_noise_pred(p, n, gs, cts, dn): - return self.sd.pipeline.predict_noise( + return self.predict_noise( latents=dn, - prompt_embeds=p.text_embeds, - negative_prompt_embeds=n.text_embeds, - pooled_prompt_embeds=p.pooled_embeds, - negative_pooled_prompt_embeds=n.pooled_embeds, + text_embeddings=train_tools.concat_prompt_embeddings( + p, # negative prompt + n, # positive prompt + self.train_config.batch_size, + ), timestep=cts, guidance_scale=gs, - num_images_per_prompt=self.train_config.batch_size, - num_inference_steps=1000, ) + # set network multiplier + self.network.multiplier = multiplier + with torch.no_grad(): self.sd.noise_scheduler.set_timesteps( self.train_config.max_denoising_steps, device=self.device_torch @@ -284,40 +397,20 @@ class TrainSliderProcess(BaseSDTrainProcess): latents = noise * self.sd.noise_scheduler.init_noise_sigma latents = latents.to(self.device_torch, dtype=dtype) - denoised_fraction = timesteps_to / (self.train_config.max_denoising_steps + 1) - self.sd.pipeline.to(self.device_torch) - torch.set_default_device(self.device_torch) - self.sd.pipeline.set_progress_bar_config(disable=True) - with self.network: assert self.network.is_active - self.network.multiplier = 1.0 - POS_denoised_latents = self.sd.pipeline( - num_inference_steps=self.train_config.max_denoising_steps, - denoising_end=denoised_fraction, - latents=latents, - prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds, - negative_prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds, - pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds, - negative_pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds, - output_type="latent", - num_images_per_prompt=self.train_config.batch_size, + self.network.multiplier = multiplier + denoised_latents = self.diffuse_some_steps( + latents, # pass simple noise latents + train_tools.concat_prompt_embeddings( + positive, # unconditional + target_class, # target + self.train_config.batch_size, + ), + start_timesteps=0, + total_timesteps=timesteps_to, guidance_scale=3, - ).images.to(self.device_torch, dtype=dtype) - - self.network.multiplier = -1.0 - NEG_denoised_latents = self.sd.pipeline( - num_inference_steps=self.train_config.max_denoising_steps, - denoising_end=denoised_fraction, - latents=latents, - prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds, - negative_prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds, - pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds, - negative_pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds, - output_type="latent", - num_images_per_prompt=self.train_config.batch_size, - guidance_scale=3, - ).images.to(self.device_torch, dtype=dtype) + ) noise_scheduler.set_timesteps(1000) @@ -325,103 +418,78 @@ class TrainSliderProcess(BaseSDTrainProcess): int(timesteps_to * 1000 / self.train_config.max_denoising_steps) ] - assert not self.network.is_active + positive_latents = get_noise_pred( + positive, negative, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + neutral_latents = get_noise_pred( + positive, neutral, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) - # POSITIVE LATENTS - POS_positive_latents = get_noise_pred( - prompt_pair.negative_target_with_neutral, - prompt_pair.positive_target_with_neutral, - 1, current_timestep, POS_denoised_latents, - ) - NEG_positive_latents = get_noise_pred( - prompt_pair.positive_target_with_neutral, - prompt_pair.negative_target_with_neutral, - 1, current_timestep, NEG_denoised_latents, - ) + unconditional_latents = get_noise_pred( + positive, positive, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + anchor_loss = None + if len(self.anchor_pairs) > 0: + # get a random anchor pair + anchor: EncodedAnchor = self.anchor_pairs[ + torch.randint(0, len(self.anchor_pairs), (1,)).item() + ] + with torch.no_grad(): + anchor_target_noise = get_noise_pred( + anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + with self.network: + # anchor whatever weight prompt pair is using + pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0 + self.network.multiplier = anchor.multiplier * pos_nem_mult - # NEUTRAL LATENTS - POS_neutral_latents = get_noise_pred( - prompt_pair.neutral, - prompt_pair.positive_target_with_neutral, - 1, current_timestep, POS_denoised_latents, - ) - NEG_neutral_latents = get_noise_pred( - prompt_pair.neutral, - prompt_pair.negative_target_with_neutral, - 1, current_timestep, NEG_denoised_latents, - ) + anchor_pred_noise = get_noise_pred( + anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) - - # UNCONDITIONAL LATENTS - POS_unconditional_latents = get_noise_pred( - prompt_pair.positive_target_with_neutral, - prompt_pair.positive_target_with_neutral, - 1, current_timestep, POS_denoised_latents, - ) - NEG_unconditional_latents = get_noise_pred( - prompt_pair.negative_target_with_neutral, - prompt_pair.negative_target_with_neutral, - 1, current_timestep, NEG_denoised_latents, - ) - - - # start grads - self.optimizer.zero_grad() + self.network.multiplier = prompt_pair.multiplier with self.network: - assert self.network.is_active - self.network.multiplier = 1.0 - POS_target_latents = get_noise_pred( - prompt_pair.negative_target_with_neutral, - prompt_pair.positive_target_with_neutral, - 1, current_timestep, POS_denoised_latents, + self.network.multiplier = prompt_pair.multiplier + target_latents = get_noise_pred( + positive, target_class, 1, current_timestep, denoised_latents + ).to("cpu", dtype=torch.float32) + + # if self.logging_config.verbose: + # self.print("target_latents:", target_latents[0, 0, :5, :5]) + + positive_latents.requires_grad = False + neutral_latents.requires_grad = False + unconditional_latents.requires_grad = False + if len(self.anchor_pairs) > 0: + anchor_target_noise.requires_grad = False + anchor_loss = loss_function( + anchor_target_noise, + anchor_pred_noise, ) - - self.network.multiplier = -1.0 - NEG_target_latents = get_noise_pred( - prompt_pair.positive_target_with_neutral, - prompt_pair.negative_target_with_neutral, - 1, current_timestep, NEG_denoised_latents, - ) - - POS_positive_latents.requires_grad = False - NEG_positive_latents.requires_grad = False - POS_neutral_latents.requires_grad = False - NEG_neutral_latents.requires_grad = False - POS_unconditional_latents.requires_grad = False - NEG_unconditional_latents.requires_grad = False - + erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE guidance_scale = 1.0 - POS_offset = guidance_scale * (POS_positive_latents - POS_unconditional_latents) - NEG_offset = guidance_scale * (NEG_positive_latents - NEG_unconditional_latents) + offset = guidance_scale * (positive_latents - unconditional_latents) - erase = True + offset_neutral = neutral_latents + if erase: + offset_neutral -= offset + else: + # enhance + offset_neutral += offset - POS_offset_neutral = POS_neutral_latents - NEG_offset_neutral = NEG_neutral_latents - # if erase: - # POS_offset_neutral -= POS_offset - # NEG_offset_neutral -= NEG_offset - # else: - # # enhance - # POS_offset_neutral += POS_offset - # NEG_offset_neutral += NEG_offset + loss = loss_function( + target_latents, + offset_neutral, + ) * weight - POS_erase_loss = loss_function( - POS_target_latents, - POS_neutral_latents - POS_offset, - ) * prompt_pair.weight + loss_slide = loss.item() - NEG_erase_loss = loss_function( - NEG_target_latents, - NEG_neutral_latents - NEG_offset, - ) * prompt_pair.weight - - - loss = (POS_erase_loss + NEG_erase_loss) * 0.5 + if anchor_loss is not None: + loss += anchor_loss loss_float = loss.item() @@ -432,28 +500,11 @@ class TrainSliderProcess(BaseSDTrainProcess): lr_scheduler.step() del ( - # denoised_latents, - POS_denoised_latents, - NEG_denoised_latents, - # positive_neg_noise_prediction, - POS_positive_latents, - NEG_positive_latents, - # neutral_noise_prediction, - POS_neutral_latents, - NEG_neutral_latents, - # unconditional_noise_prediction, - POS_unconditional_latents, - NEG_unconditional_latents, - # target_noise_prediction, - POS_target_latents, - NEG_target_latents, - # offset, - POS_offset, - NEG_offset, - # offset_neutral, - POS_offset_neutral, - NEG_offset_neutral, - + positive_latents, + neutral_latents, + unconditional_latents, + target_latents, + latents, ) # move back to cpu prompt_pair.to("cpu") @@ -463,12 +514,11 @@ class TrainSliderProcess(BaseSDTrainProcess): self.network.multiplier = 1.0 loss_dict = OrderedDict( - { - 'loss': loss.item(), - 'l+er': POS_erase_loss.item(), - 'l-er': NEG_erase_loss.item(), - }, + {'loss': loss_float}, ) + if anchor_loss is not None: + loss_dict['sl_l'] = loss_slide + loss_dict['an_l'] = anchor_loss.item() return loss_dict # end hook_train_loop