diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py
index 1d5282ef..3fc9b7db 100644
--- a/jobs/TrainJob.py
+++ b/jobs/TrainJob.py
@@ -16,7 +16,8 @@ sys.path.append(REPOS_ROOT)

 process_dict = {
     'vae': 'TrainVAEProcess',
-    'slider': 'TrainSliderProcess',
+    'slider_dev': 'TrainSliderProcess',
+    'slider': 'TrainSliderProcessOld',
     'lora_hack': 'TrainLoRAHack',
     'rescale_sd': 'TrainSDRescaleProcess',
 }
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index c7b89a2a..c0d15ae1 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -5,6 +5,9 @@ from collections import OrderedDict
 import os
 from typing import Optional

+from safetensors.torch import save_file, load_file
+from tqdm import tqdm
+
 from toolkit.config_modules import SliderConfig
 from toolkit.paths import REPOS_ROOT
 import sys
@@ -35,28 +38,35 @@ def flush():
 class EncodedPromptPair:
     def __init__(
             self,
-            target_class,
-            positive,
-            negative,
+            positive_target,
+            positive_target_with_neutral,
+            negative_target,
+            negative_target_with_neutral,
             neutral,
-            width=512,
-            height=512,
-            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
-            multiplier=1.0,
-            weight=1.0
+            both_targets,
+            empty_prompt
     ):
-        self.target_class = target_class
-        self.positive = positive
-        self.negative = negative
+        self.positive_target = positive_target
+        self.positive_target_with_neutral = positive_target_with_neutral
+        self.negative_target = negative_target
+        self.negative_target_with_neutral = negative_target_with_neutral
         self.neutral = neutral
-        self.width = width
-        self.height = height
-        self.action: int = action
-        self.multiplier = multiplier
-        self.weight = weight
+        self.empty_prompt = empty_prompt
+        self.both_targets = both_targets
+
+    # mimic torch.Tensor.to() for all of the embedding tensors held here
+    def to(self, *args, **kwargs):
+        self.positive_target = self.positive_target.to(*args, **kwargs)
+        self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
+        self.negative_target = self.negative_target.to(*args, **kwargs)
+        self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs)
+        self.neutral = self.neutral.to(*args, **kwargs)
+        self.empty_prompt = self.empty_prompt.to(*args, **kwargs)
+        self.both_targets = self.both_targets.to(*args, **kwargs)
+        return self

-class PromptEmbedsCache:  # 使いまわしたいので
+class PromptEmbedsCache:
     prompts: dict[str, PromptEmbeds] = {}

     def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
@@ -84,6 +94,7 @@ class EncodedAnchor:
 class TrainSliderProcess(BaseSDTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
+        self.prompt_txt_list = None
         self.step_num = 0
         self.start_step = 0
         self.device = self.get_conf('device', self.job.device)
@@ -97,115 +108,95 @@ class TrainSliderProcess(BaseSDTrainProcess):
         pass

     def hook_before_train_loop(self):
+        self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
+
+        # read line by line from file
+        with open(self.slider_config.prompt_file, 'r') as f:
+            self.prompt_txt_list = f.readlines()
+        # clean empty lines
+        self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
+
+        self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them...")
+
         cache = PromptEmbedsCache()
-        prompt_pairs: list[EncodedPromptPair] = []

         # get encoded latents for our prompts
         with torch.no_grad():
-            neutral = ""
-            for target in self.slider_config.targets:
-                # build the cache
-                for prompt in [
-                    target.target_class,
-                    target.positive,
-                    target.negative,
-                    neutral  # empty neutral
-                ]:
-                    if cache[prompt] is None:
-                        cache[prompt] = self.sd.encode_prompt(prompt)
-            for resolution in self.slider_config.resolutions:
-                width, height = resolution
-                only_erase = len(target.positive.strip()) == 0
-                only_enhance = len(target.negative.strip()) == 0
+            if self.slider_config.prompt_tensors is not None:
+                # check to see if it exists
+                if os.path.exists(self.slider_config.prompt_tensors):
+                    # load it.
+                    self.print(f"Loading prompt tensors from {self.slider_config.prompt_tensors}")
+                    prompt_tensors = load_file(self.slider_config.prompt_tensors, device='cpu')
+                    # add them to the cache
+                    for prompt_txt, prompt_tensor in prompt_tensors.items():
+                        if prompt_txt.startswith("te:"):
+                            prompt = prompt_txt[3:]
+                            # text_embeds
+                            text_embeds = prompt_tensor
+                            pooled_embeds = None
+                            # find the matching pooled embeds
+                            if f"pe:{prompt}" in prompt_tensors:
+                                pooled_embeds = prompt_tensors[f"pe:{prompt}"]

-                both = not only_erase and not only_enhance
+                            # make the prompt embeds
+                            prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
+                            cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)

-                if only_erase and only_enhance:
-                    raise ValueError("target must have at least one of positive or negative or both")
-                # for slider we need to have an enhancer, an eraser, and then
-                # an inverse with negative weights to balance the network
-                # if we don't do this, we will get different contrast and focus.
-                # we only perform actions of enhancing and erasing on the negative
-                # todo work on way to do all of this in one shot
+            if len(cache.prompts) == 0:
+                print("Prompt tensors not found. Encoding prompts...")
+                empty_prompt = ""
+                # encode the empty prompt
+                cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)

-                if both or only_erase:
-                    prompt_pairs += [
-                        # erase standard
-                        EncodedPromptPair(
-                            target_class=cache[target.target_class],
-                            positive=cache[target.positive],
-                            negative=cache[target.negative],
-                            neutral=cache[neutral],
-                            width=width,
-                            height=height,
-                            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
-                            multiplier=target.multiplier,
-                            weight=target.weight
-                        ),
-                    ]
-                if both or only_enhance:
-                    prompt_pairs += [
-                        # enhance standard, swap pos neg
-                        EncodedPromptPair(
-                            target_class=cache[target.target_class],
-                            positive=cache[target.negative],
-                            negative=cache[target.positive],
-                            neutral=cache[neutral],
-                            width=width,
-                            height=height,
-                            action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
-                            multiplier=target.multiplier,
-                            weight=target.weight
-                        ),
-                    ]
-                if both:
-                    prompt_pairs += [
-                        # erase inverted
-                        EncodedPromptPair(
-                            target_class=cache[target.target_class],
-                            positive=cache[target.negative],
-                            negative=cache[target.positive],
-                            neutral=cache[neutral],
-                            width=width,
-                            height=height,
-                            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
-                            multiplier=target.multiplier * -1.0,
-                            weight=target.weight
-                        ),
-                    ]
-                    prompt_pairs += [
-                        # enhance inverted
-                        EncodedPromptPair(
-                            target_class=cache[target.target_class],
-                            positive=cache[target.positive],
-                            negative=cache[target.negative],
-                            neutral=cache[neutral],
-                            width=width,
-                            height=height,
-                            action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
-                            multiplier=target.multiplier * -1.0,
-                            weight=target.weight
-                        ),
+            for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
+                for target in self.slider_config.targets:
+                    prompt_list = [
+                        f"{target.positive}",  # positive_target
+                        f"{target.positive} {neutral}",  # positive_target with neutral
+                        f"{target.negative}",  # negative_target
+                        f"{target.negative} {neutral}",  # negative_target with neutral
+                        f"{neutral}",  # neutral
+                        f"{target.positive} {target.negative}",  # both targets
+                        f"{target.negative} {target.positive}",  # both targets, reversed order
                     ]
+                    for p in prompt_list:
+                        # build the cache
+                        if cache[p] is None:
+                            cache[p] = self.sd.encode_prompt(p).to(device="cpu", dtype=torch.float32)

-        # setup anchors
-        anchor_pairs = []
-        for anchor in self.slider_config.anchors:
-            # build the cache
-            for prompt in [
-                anchor.prompt,
-                anchor.neg_prompt  # empty neutral
-            ]:
-                if cache[prompt] == None:
-                    cache[prompt] = self.sd.encode_prompt(prompt)
+            if self.slider_config.prompt_tensors:
+                print(f"Saving prompt tensors to {self.slider_config.prompt_tensors}")
+                state_dict = {}
+                for prompt_txt, prompt_embeds in cache.prompts.items():
+                    state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
+                    if prompt_embeds.pooled_embeds is not None:
+                        state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
+                save_file(state_dict, self.slider_config.prompt_tensors)

-            anchor_pairs += [
-                EncodedAnchor(
-                    prompt=cache[anchor.prompt],
-                    neg_prompt=cache[anchor.neg_prompt],
-                    multiplier=anchor.multiplier
-                )
-            ]
+        self.print("Encoding complete. Building prompt pairs...")
+        for neutral in self.prompt_txt_list:
+            for target in self.slider_config.targets:
+                both_prompts_list = [
+                    f"{target.positive} {target.negative}",
+                    f"{target.negative} {target.positive}",
+                ]
+                # randomly pick one of the two orderings to prevent bias
+                both_prompts = both_prompts_list[torch.randint(0, 2, (1,)).item()]
+
+                prompt_pair = EncodedPromptPair(
+                    positive_target=cache[f"{target.positive}"],
+                    positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
+                    negative_target=cache[f"{target.negative}"],
+                    negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
+                    neutral=cache[neutral],
+                    both_targets=cache[both_prompts],
+                    empty_prompt=cache[""],
+                ).to(device="cpu", dtype=torch.float32)
+                self.prompt_pairs.append(prompt_pair)

         # move to cpu to save vram
         # We don't need text encoder anymore, but keep it on cpu for sampling
@@ -216,8 +207,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
         else:
             self.sd.text_encoder.to("cpu")
         self.prompt_cache = cache
-        self.prompt_pairs = prompt_pairs
-        self.anchor_pairs = anchor_pairs
+
         flush()
         # end hook_before_train_loop

@@ -228,15 +218,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
         prompt_pair: EncodedPromptPair = self.prompt_pairs[
             torch.randint(0, len(self.prompt_pairs), (1,)).item()
         ]
+        # move to device and dtype
+        prompt_pair.to(self.device_torch, dtype=dtype)

-        height = prompt_pair.height
-        width = prompt_pair.width
-        target_class = prompt_pair.target_class
-        neutral = prompt_pair.neutral
-        negative = prompt_pair.negative
-        positive = prompt_pair.positive
-        weight = prompt_pair.weight
-        multiplier = prompt_pair.multiplier
+        # get a random resolution
+        height, width = self.slider_config.resolutions[
+            torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
+        ]

         unet = self.sd.unet
         noise_scheduler = self.sd.noise_scheduler
@@ -244,21 +232,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
         lr_scheduler = self.lr_scheduler
         loss_function = torch.nn.MSELoss()

-        def get_noise_pred(p, n, gs, cts, dn):
-            return self.predict_noise(
-                latents=dn,
-                text_embeddings=train_tools.concat_prompt_embeddings(
-                    p,  # unconditional
-                    n,  # positive
-                    self.train_config.batch_size,
-                ),
-                timestep=cts,
-                guidance_scale=gs,
-            )
-
-        # set network multiplier
-        self.network.multiplier = multiplier
-
         with torch.no_grad():
             self.sd.noise_scheduler.set_timesteps(
                 self.train_config.max_denoising_steps, device=self.device_torch
             )
@@ -281,99 +254,154 @@ class TrainSliderProcess(BaseSDTrainProcess):
             latents = noise * self.sd.noise_scheduler.init_noise_sigma
             latents = latents.to(self.device_torch, dtype=dtype)

-            with self.network:
-                assert self.network.is_active
-                self.network.multiplier = multiplier
-                denoised_latents = self.diffuse_some_steps(
-                    latents,  # pass simple noise latents
-                    train_tools.concat_prompt_embeddings(
-                        positive,  # unconditional
-                        target_class,  # target
-                        self.train_config.batch_size,
-                    ),
-                    start_timesteps=0,
-                    total_timesteps=timesteps_to,
-                    guidance_scale=3,
-                )
+            denoised_fraction = timesteps_to / (self.train_config.max_denoising_steps + 1)
+            self.sd.pipeline.to(self.device_torch)
+            torch.set_default_device(self.device_torch)
+            self.sd.pipeline.set_progress_bar_config(disable=True)
+
+            # generate semi-denoised latents without the network active:
+            # only the neutral prompt in the positive and both targets in the negative
+            assert not self.network.is_active
+            # denoised_latents = self.sd.pipeline(
+            #     num_inference_steps=self.train_config.max_denoising_steps,
+            #     denoising_end=denoised_fraction,
+            #     latents=latents,
+            #     prompt_embeds=prompt_pair.neutral.text_embeds,
+            #     negative_prompt_embeds=prompt_pair.both_targets.text_embeds,
+            #     pooled_prompt_embeds=prompt_pair.neutral.pooled_embeds,
+            #     negative_pooled_prompt_embeds=prompt_pair.both_targets.pooled_embeds,
+            #     output_type="latent",
+            #     num_images_per_prompt=self.train_config.batch_size,
+            #     guidance_scale=3,
+            # ).images.to(self.device_torch, dtype=dtype)

             noise_scheduler.set_timesteps(1000)
-
             current_timestep = noise_scheduler.timesteps[
                 int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
             ]
+            denoised_latents = noise

-            positive_latents = get_noise_pred(
-                positive, negative, 1, current_timestep, denoised_latents
-            ).to("cpu", dtype=torch.float32)
-
-            neutral_latents = get_noise_pred(
-                positive, neutral, 1, current_timestep, denoised_latents
-            ).to("cpu", dtype=torch.float32)
-
-            unconditional_latents = get_noise_pred(
-                positive, positive, 1, current_timestep, denoised_latents
-            ).to("cpu", dtype=torch.float32)
-
-            anchor_loss = None
-            if len(self.anchor_pairs) > 0:
-                # get a random anchor pair
-                anchor: EncodedAnchor = self.anchor_pairs[
-                    torch.randint(0, len(self.anchor_pairs), (1,)).item()
-                ]
-                with torch.no_grad():
-                    anchor_target_noise = get_noise_pred(
-                        anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
-                    ).to("cpu", dtype=torch.float32)
-                with self.network:
-                    # anchor whatever weight prompt pair is using
-                    pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
-                    self.network.multiplier = anchor.multiplier * pos_nem_mult
-
-                    anchor_pred_noise = get_noise_pred(
-                        anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
-                    ).to("cpu", dtype=torch.float32)
-
-                    self.network.multiplier = prompt_pair.multiplier
-
-        with self.network:
-            self.network.multiplier = prompt_pair.multiplier
-            target_latents = get_noise_pred(
-                positive, target_class, 1, current_timestep, denoised_latents
-            ).to("cpu", dtype=torch.float32)
-
-            # if self.logging_config.verbose:
-            #     self.print("target_latents:", target_latents[0, 0, :5, :5])
-
-        positive_latents.requires_grad = False
-        neutral_latents.requires_grad = False
-        unconditional_latents.requires_grad = False
-        if len(self.anchor_pairs) > 0:
-            anchor_target_noise.requires_grad = False
-            anchor_loss = loss_function(
-                anchor_target_noise,
-                anchor_pred_noise,
+            # neutral prediction
+            neutral_noise_prediction = self.sd.pipeline.predict_noise(
+                latents=denoised_latents,
+                prompt_embeds=prompt_pair.neutral.text_embeds,
+                negative_prompt_embeds=prompt_pair.empty_prompt.text_embeds,
+                pooled_prompt_embeds=prompt_pair.neutral.pooled_embeds,
+                negative_pooled_prompt_embeds=prompt_pair.both_targets.pooled_embeds,
+                timestep=current_timestep,
+                guidance_scale=1,
+                num_images_per_prompt=self.train_config.batch_size,
+                num_inference_steps=1000,
             )
-        erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
-        guidance_scale = 1.0

-        offset = guidance_scale * (positive_latents - unconditional_latents)
+            # with self.network:
+            #     assert self.network.is_active
+            #     self.network.multiplier = 1.0
+            #
+            #     positive_pos_noise_prediction = self.sd.pipeline.predict_noise(
+            #         latents=denoised_latents,
+            #         prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
+            #         negative_prompt_embeds=prompt_pair.negative_target.text_embeds,
+            #         pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
+            #         negative_pooled_prompt_embeds=prompt_pair.negative_target.pooled_embeds,
+            #         timestep=current_timestep,
+            #         guidance_scale=1,
+            #         num_images_per_prompt=self.train_config.batch_size,
+            #         num_inference_steps=1000
+            #     )
+            #
+            #     self.network.multiplier = -1.0
+            #
+            #     negative_neg_noise_prediction = self.sd.pipeline.predict_noise(
+            #         latents=denoised_latents,
+            #         prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
+            #         negative_prompt_embeds=prompt_pair.positive_target.text_embeds,
+            #         pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
+            #         negative_pooled_prompt_embeds=prompt_pair.positive_target.pooled_embeds,
+            #         timestep=current_timestep,
+            #         guidance_scale=1,
+            #         num_images_per_prompt=self.train_config.batch_size,
+            #         num_inference_steps=1000
+            #     )

-            offset_neutral = neutral_latents
-            if erase:
-                offset_neutral -= offset
-            else:
-                # enhance
-                offset_neutral += offset
+        # start grads
+        self.optimizer.zero_grad()

-            loss = loss_function(
-                target_latents,
-                offset_neutral,
-            ) * weight
+        multiplier = 5.0

-            loss_slide = loss.item()
+        # predict positive
+        with self.network:
+            assert self.network.is_active
+            self.network.multiplier = multiplier * 1.0

-            if anchor_loss is not None:
-                loss += anchor_loss
+            # positive_pos_noise_prediction = self.sd.pipeline.predict_noise(
+            #     latents=denoised_latents,
+            #     prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
+            #     negative_prompt_embeds=prompt_pair.negative_target.text_embeds,
+            #     pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
+            #     negative_pooled_prompt_embeds=prompt_pair.negative_target.pooled_embeds,
+            #     timestep=current_timestep,
+            #     guidance_scale=1,
+            #     num_images_per_prompt=self.train_config.batch_size,
+            #     num_inference_steps=self.train_config.max_denoising_steps,
+            # )
+
+            negative_pos_noise_prediction = self.sd.pipeline.predict_noise(
+                latents=denoised_latents,
+                prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
+                negative_prompt_embeds=prompt_pair.positive_target.text_embeds,
+                pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
+                negative_pooled_prompt_embeds=prompt_pair.positive_target.pooled_embeds,
+                timestep=current_timestep,
+                guidance_scale=1,
+                num_images_per_prompt=self.train_config.batch_size,
+                num_inference_steps=1000,
+            )
+
+            self.network.multiplier = multiplier * -1.0
+
+            positive_neg_noise_prediction = self.sd.pipeline.predict_noise(
+                latents=denoised_latents,
+                prompt_embeds=prompt_pair.positive_target_with_neutral.text_embeds,
+                negative_prompt_embeds=prompt_pair.negative_target.text_embeds,
+                pooled_prompt_embeds=prompt_pair.positive_target_with_neutral.pooled_embeds,
+                negative_pooled_prompt_embeds=prompt_pair.negative_target.pooled_embeds,
+                timestep=current_timestep,
+                guidance_scale=1,
+                num_images_per_prompt=self.train_config.batch_size,
+                num_inference_steps=1000,
+            )
+
+            # negative_neg_noise_prediction = self.sd.pipeline.predict_noise(
+            #     latents=denoised_latents,
+            #     prompt_embeds=prompt_pair.negative_target_with_neutral.text_embeds,
+            #     negative_prompt_embeds=prompt_pair.positive_target.text_embeds,
+            #     pooled_prompt_embeds=prompt_pair.negative_target_with_neutral.pooled_embeds,
+            #     negative_pooled_prompt_embeds=prompt_pair.positive_target.pooled_embeds,
+            #     timestep=current_timestep,
+            #     guidance_scale=1,
+            #     num_images_per_prompt=self.train_config.batch_size,
+            #     num_inference_steps=self.train_config.max_denoising_steps,
+            # )
+
+        self.network.multiplier = 1.0
+
+        neutral_noise_prediction.requires_grad = False
+        # positive_pos_noise_prediction.requires_grad = False
+        # negative_neg_noise_prediction.requires_grad = False
+
+        # calculate loss
+        loss_shrink_pos_neg = loss_function(
+            negative_pos_noise_prediction,
+            neutral_noise_prediction,
+        )
+
+        loss_shrink_neg_pos = loss_function(
+            positive_neg_noise_prediction,
+            negative_pos_noise_prediction,
+        )
+
+        loss = loss_shrink_pos_neg + loss_shrink_neg_pos

         loss_float = loss.item()
@@ -384,12 +412,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
         lr_scheduler.step()

         del (
-            positive_latents,
-            neutral_latents,
-            unconditional_latents,
-            target_latents,
+            denoised_latents,
+            positive_neg_noise_prediction,
+            negative_pos_noise_prediction,
+            neutral_noise_prediction,
             latents,
         )
+        # move back to cpu
+        prompt_pair.to("cpu")
         flush()

         # reset network
@@ -398,9 +428,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
         loss_dict = OrderedDict(
             {'loss': loss_float},
         )
-        if anchor_loss is not None:
-            loss_dict['sl_l'] = loss_slide
-            loss_dict['an_l'] = anchor_loss.item()

         return loss_dict
         # end hook_train_loop
diff --git a/jobs/process/TrainSliderProcessOld.py b/jobs/process/TrainSliderProcessOld.py
new file mode 100644
index 00000000..a33f6314
--- /dev/null
+++ b/jobs/process/TrainSliderProcessOld.py
@@ -0,0 +1,406 @@
+# ref:
+# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
+import time
+from collections import OrderedDict
+import os
+from typing import Optional
+
+from toolkit.config_modules import SliderConfig
+from toolkit.paths import REPOS_ROOT
+import sys
+
+from toolkit.stable_diffusion_model import PromptEmbeds
+
+sys.path.append(REPOS_ROOT)
+sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
+from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+import gc
+from toolkit import train_tools
+
+import torch
+from leco import train_util, model_util
+from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
+
+
+class ACTION_TYPES_SLIDER:
+    ERASE_NEGATIVE = 0
+    ENHANCE_NEGATIVE = 1
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class EncodedPromptPair:
+    def __init__(
+            self,
+            target_class,
+            positive,
+            negative,
+            neutral,
+            width=512,
+            height=512,
+            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+            multiplier=1.0,
+            weight=1.0
+    ):
+        self.target_class = target_class
+        self.positive = positive
+        self.negative = negative
+        self.neutral = neutral
+        self.width = width
+        self.height = height
+        self.action: int = action
+        self.multiplier = multiplier
+        self.weight = weight
+
+
+class PromptEmbedsCache:  # kept around so prompt embeds can be reused
+    prompts: dict[str, PromptEmbeds] = {}
+
+    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
+        self.prompts[__name] = __value
+
+    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
+        if __name in self.prompts:
+            return self.prompts[__name]
+        else:
+            return None
+
+
+class EncodedAnchor:
+    def __init__(
+            self,
+            prompt,
+            neg_prompt,
+            multiplier=1.0
+    ):
+        self.prompt = prompt
+        self.neg_prompt = neg_prompt
+        self.multiplier = multiplier
+
+
+class TrainSliderProcessOld(BaseSDTrainProcess):
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.step_num = 0
+        self.start_step = 0
+        self.device = self.get_conf('device', self.job.device)
+        self.device_torch = torch.device(self.device)
+        self.slider_config = SliderConfig(**self.get_conf('slider', {}))
+        self.prompt_cache = PromptEmbedsCache()
+        self.prompt_pairs: list[EncodedPromptPair] = []
+        self.anchor_pairs: list[EncodedAnchor] = []
+
+    def before_model_load(self):
+        pass
+
+    def hook_before_train_loop(self):
+        cache = PromptEmbedsCache()
+        prompt_pairs: list[EncodedPromptPair] = []
+
+        # get encoded latents for our prompts
+        with torch.no_grad():
+            neutral = ""
+            for target in self.slider_config.targets:
+                # build the cache
+                for prompt in [
+                    target.target_class,
+                    target.positive,
+                    target.negative,
+                    neutral  # empty neutral
+                ]:
+                    if cache[prompt] is None:
+                        cache[prompt] = self.sd.encode_prompt(prompt)
+            for resolution in self.slider_config.resolutions:
+                width, height = resolution
+                only_erase = len(target.positive.strip()) == 0
+                only_enhance = len(target.negative.strip()) == 0
+
+                both = not only_erase and not only_enhance
+
+                if only_erase and only_enhance:
+                    raise ValueError("target must have at least one of positive or negative or both")
+                # for slider we need to have an enhancer, an eraser, and then
+                # an inverse with negative weights to balance the network
+                # if we don't do this, we will get different contrast and focus.
+                # we only perform actions of enhancing and erasing on the negative
+                # todo work on way to do all of this in one shot
+
+                if both or only_erase:
+                    prompt_pairs += [
+                        # erase standard
+                        EncodedPromptPair(
+                            target_class=cache[target.target_class],
+                            positive=cache[target.positive],
+                            negative=cache[target.negative],
+                            neutral=cache[neutral],
+                            width=width,
+                            height=height,
+                            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+                            multiplier=target.multiplier,
+                            weight=target.weight
+                        ),
+                    ]
+                if both or only_enhance:
+                    prompt_pairs += [
+                        # enhance standard, swap pos neg
+                        EncodedPromptPair(
+                            target_class=cache[target.target_class],
+                            positive=cache[target.negative],
+                            negative=cache[target.positive],
+                            neutral=cache[neutral],
+                            width=width,
+                            height=height,
+                            action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+                            multiplier=target.multiplier,
+                            weight=target.weight
+                        ),
+                    ]
+                if both:
+                    prompt_pairs += [
+                        # erase inverted
+                        EncodedPromptPair(
+                            target_class=cache[target.target_class],
+                            positive=cache[target.negative],
+                            negative=cache[target.positive],
+                            neutral=cache[neutral],
+                            width=width,
+                            height=height,
+                            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+                            multiplier=target.multiplier * -1.0,
+                            weight=target.weight
+                        ),
+                    ]
+                    prompt_pairs += [
+                        # enhance inverted
+                        EncodedPromptPair(
+                            target_class=cache[target.target_class],
+                            positive=cache[target.positive],
+                            negative=cache[target.negative],
+                            neutral=cache[neutral],
+                            width=width,
+                            height=height,
+                            action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+                            multiplier=target.multiplier * -1.0,
+                            weight=target.weight
+                        ),
+                    ]
+
+        # setup anchors
+        anchor_pairs = []
+        for anchor in self.slider_config.anchors:
+            # build the cache
+            for prompt in [
+                anchor.prompt,
+                anchor.neg_prompt  # empty neutral
+            ]:
+                if cache[prompt] == None:
+                    cache[prompt] = self.sd.encode_prompt(prompt)
+
+            anchor_pairs += [
+                EncodedAnchor(
+                    prompt=cache[anchor.prompt],
+                    neg_prompt=cache[anchor.neg_prompt],
+                    multiplier=anchor.multiplier
+                )
+            ]
+
+        # move to cpu to save vram
+        # We don't need text encoder anymore, but keep it on cpu for sampling
+        # if text encoder is list
+        if isinstance(self.sd.text_encoder, list):
+            for encoder in self.sd.text_encoder:
+                encoder.to("cpu")
+        else:
+            self.sd.text_encoder.to("cpu")
+        self.prompt_cache = cache
+        self.prompt_pairs = prompt_pairs
+        self.anchor_pairs = anchor_pairs
+        flush()
+        # end hook_before_train_loop
+
+    def hook_train_loop(self):
+        dtype = get_torch_dtype(self.train_config.dtype)
+
+        # get a random pair
+        prompt_pair: EncodedPromptPair = self.prompt_pairs[
+            torch.randint(0, len(self.prompt_pairs), (1,)).item()
+        ]
+
+        height = prompt_pair.height
+        width = prompt_pair.width
+        target_class = prompt_pair.target_class
+        neutral = prompt_pair.neutral
+        negative = prompt_pair.negative
+        positive = prompt_pair.positive
+        weight = prompt_pair.weight
+        multiplier = prompt_pair.multiplier
+
+        unet = self.sd.unet
+        noise_scheduler = self.sd.noise_scheduler
+        optimizer = self.optimizer
+        lr_scheduler = self.lr_scheduler
+        loss_function = torch.nn.MSELoss()
+
+        def get_noise_pred(p, n, gs, cts, dn):
+            return self.predict_noise(
+                latents=dn,
+                text_embeddings=train_tools.concat_prompt_embeddings(
+                    p,  # unconditional
+                    n,  # positive
+                    self.train_config.batch_size,
+                ),
+                timestep=cts,
+                guidance_scale=gs,
+            )
+
+        # set network multiplier
+        self.network.multiplier = multiplier
+
+        with torch.no_grad():
+            self.sd.noise_scheduler.set_timesteps(
+                self.train_config.max_denoising_steps, device=self.device_torch
+            )
+
+            self.optimizer.zero_grad()
+
+            # get a random number of steps
+            timesteps_to = torch.randint(
+                1, self.train_config.max_denoising_steps, (1,)
+            ).item()
+
+            # get noise
+            noise = self.get_latent_noise(
+                pixel_height=height,
+                pixel_width=width,
+            ).to(self.device_torch, dtype=dtype)
+
+            # get latents
+            latents = noise * self.sd.noise_scheduler.init_noise_sigma
+            latents = latents.to(self.device_torch, dtype=dtype)
+
+            with self.network:
+                assert self.network.is_active
+                self.network.multiplier = multiplier
+                denoised_latents = self.diffuse_some_steps(
+                    latents,  # pass simple noise latents
+                    train_tools.concat_prompt_embeddings(
+                        positive,  # unconditional
+                        target_class,  # target
+                        self.train_config.batch_size,
+                    ),
+                    start_timesteps=0,
+                    total_timesteps=timesteps_to,
+                    guidance_scale=3,
+                )
+
+            noise_scheduler.set_timesteps(1000)
+
+            current_timestep = noise_scheduler.timesteps[
+                int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
+            ]
+
+            positive_latents = get_noise_pred(
+                positive, negative, 1, current_timestep, denoised_latents
+            ).to("cpu", dtype=torch.float32)
+
+            neutral_latents = get_noise_pred(
+                positive, neutral, 1, current_timestep, denoised_latents
+            ).to("cpu", dtype=torch.float32)
+
+            unconditional_latents = get_noise_pred(
+                positive, positive, 1, current_timestep, denoised_latents
+            ).to("cpu", dtype=torch.float32)
+
+        anchor_loss = None
+        if len(self.anchor_pairs) > 0:
+            # get a random anchor pair
+            anchor: EncodedAnchor = self.anchor_pairs[
+                torch.randint(0, len(self.anchor_pairs), (1,)).item()
+            ]
+            with torch.no_grad():
+                anchor_target_noise = get_noise_pred(
+                    anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
+                ).to("cpu", dtype=torch.float32)
+            with self.network:
+                # anchor whatever weight prompt pair is using
+                pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
+                self.network.multiplier = anchor.multiplier * pos_nem_mult
+
+                anchor_pred_noise = get_noise_pred(
+                    anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
+                ).to("cpu", dtype=torch.float32)
+
+                self.network.multiplier = prompt_pair.multiplier
+
+        with self.network:
+            self.network.multiplier = prompt_pair.multiplier
+            target_latents = get_noise_pred(
+                positive, target_class, 1, current_timestep, denoised_latents
+            ).to("cpu", dtype=torch.float32)
+
+            # if self.logging_config.verbose:
+            #     self.print("target_latents:", target_latents[0, 0, :5, :5])
+
+        positive_latents.requires_grad = False
+        neutral_latents.requires_grad = False
+        unconditional_latents.requires_grad = False
+        if len(self.anchor_pairs) > 0:
+            anchor_target_noise.requires_grad = False
+            anchor_loss = loss_function(
+                anchor_target_noise,
+                anchor_pred_noise,
+            )
+        erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
+        guidance_scale = 1.0
+
+        offset = guidance_scale * (positive_latents - unconditional_latents)
+
+        offset_neutral = neutral_latents
+        if erase:
+            offset_neutral -= offset
+        else:
+            # enhance
+            offset_neutral += offset
+
+        loss = loss_function(
+            target_latents,
+            offset_neutral,
+        ) * weight
+
+        loss_slide = loss.item()
+
+        if anchor_loss is not None:
+            loss += anchor_loss
+
+        loss_float = loss.item()
+
+        loss = loss.to(self.device_torch)
+
+        loss.backward()
+        optimizer.step()
+        lr_scheduler.step()
+
+        del (
+            positive_latents,
+            neutral_latents,
+            unconditional_latents,
+            target_latents,
+            latents,
+        )
+        flush()
+
+        # reset network
+        self.network.multiplier = 1.0
+
+        loss_dict = OrderedDict(
+            {'loss': loss_float},
+        )
+        if anchor_loss is not None:
+            loss_dict['sl_l'] = loss_slide
+            loss_dict['an_l'] = anchor_loss.item()
+
+        return loss_dict
+        # end hook_train_loop
diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py
index 6329a213..e4fc21bc 100644
--- a/jobs/process/__init__.py
+++ b/jobs/process/__init__.py
@@ -6,5 +6,6 @@ from .BaseTrainProcess import BaseTrainProcess
 from .TrainVAEProcess import TrainVAEProcess
 from .BaseMergeProcess import BaseMergeProcess
 from .TrainSliderProcess import TrainSliderProcess
+from .TrainSliderProcessOld import TrainSliderProcessOld
 from .TrainLoRAHack import TrainLoRAHack
 from .TrainSDRescaleProcess import TrainSDRescaleProcess
\ No newline at end of file
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 5e5c3623..03bf3487 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -99,3 +99,5 @@ class SliderConfig:
         anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
         self.anchors: List[SliderConfigAnchors] = anchors
         self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
+        self.prompt_file: str = kwargs.get('prompt_file', '')
+        self.prompt_tensors: str = kwargs.get('prompt_tensors', '')
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index a763a988..8b465249 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -30,10 +30,10 @@ class PromptEmbeds:
             self.text_embeds = args
         self.pooled_embeds = None

-    def to(self, **kwargs):
-        self.text_embeds = self.text_embeds.to(**kwargs)
+    def to(self, *args, **kwargs):
+        self.text_embeds = self.text_embeds.to(*args, **kwargs)
         if self.pooled_embeds is not None:
-            self.pooled_embeds = self.pooled_embeds.to(**kwargs)
+            self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
         return self