import os
from typing import Optional, TYPE_CHECKING, List

import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype


class ACTION_TYPES_SLIDER:
    """Integer action codes attached to each slider prompt pair."""
    ERASE_NEGATIVE = 0
    ENHANCE_NEGATIVE = 1


class EncodedPromptPair:
    """Bundle of pre-encoded prompt embeddings used by the slider trainer.

    Every embedding field holds a ``PromptEmbeds``. ``multiplier_list`` and
    ``action_list`` carry one entry per concatenated pair; when they are not
    given explicitly they default to ``[multiplier]`` / ``[action]``.
    """

    # Names of every PromptEmbeds attribute; used by ``to`` to move them together.
    _EMBED_FIELDS = (
        "target_class",
        "target_class_with_neutral",
        "positive_target",
        "positive_target_with_neutral",
        "negative_target",
        "negative_target_with_neutral",
        "neutral",
        "empty_prompt",
        "both_targets",
    )

    def __init__(
            self,
            target_class,
            target_class_with_neutral,
            positive_target,
            positive_target_with_neutral,
            negative_target,
            negative_target_with_neutral,
            neutral,
            empty_prompt,
            both_targets,
            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
            action_list=None,
            multiplier=1.0,
            multiplier_list=None,
            weight=1.0
    ):
        self.target_class: PromptEmbeds = target_class
        self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
        self.positive_target: PromptEmbeds = positive_target
        self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral
        self.negative_target: PromptEmbeds = negative_target
        self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral
        self.neutral: PromptEmbeds = neutral
        self.empty_prompt: PromptEmbeds = empty_prompt
        self.both_targets: PromptEmbeds = both_targets
        self.multiplier: float = multiplier
        if multiplier_list is not None:
            self.multiplier_list: list[float] = multiplier_list
        else:
            self.multiplier_list: list[float] = [multiplier]
        self.action: int = action
        if action_list is not None:
            self.action_list: list[int] = action_list
        else:
            self.action_list: list[int] = [action]
        self.weight: float = weight

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every embedding field (mirrors torch ``.to``).

        Returns ``self`` so calls can be chained.
        """
        for field in self._EMBED_FIELDS:
            setattr(self, field, getattr(self, field).to(*args, **kwargs))
        return self


def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
    """Concatenate a list of PromptEmbeds along dim 0.

    Pooled embeddings are concatenated only when the first entry has them;
    otherwise the result's pooled embeddings are None.
    """
    text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
    pooled_embeds = None
    if prompt_embeds[0].pooled_embeds is not None:
        pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
    return PromptEmbeds([text_embeds, pooled_embeds])


def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
    """Merge several EncodedPromptPairs into one by concatenating each field on dim 0.

    Action and multiplier lists are joined in input order. The weight of the
    first pair is used for the result.
    """
    weight = prompt_pairs[0].weight
    target_class = concat_prompt_embeds([p.target_class for p in prompt_pairs])
    target_class_with_neutral = concat_prompt_embeds([p.target_class_with_neutral for p in prompt_pairs])
    positive_target = concat_prompt_embeds([p.positive_target for p in prompt_pairs])
    positive_target_with_neutral = concat_prompt_embeds([p.positive_target_with_neutral for p in prompt_pairs])
    negative_target = concat_prompt_embeds([p.negative_target for p in prompt_pairs])
    negative_target_with_neutral = concat_prompt_embeds([p.negative_target_with_neutral for p in prompt_pairs])
    neutral = concat_prompt_embeds([p.neutral for p in prompt_pairs])
    empty_prompt = concat_prompt_embeds([p.empty_prompt for p in prompt_pairs])
    both_targets = concat_prompt_embeds([p.both_targets for p in prompt_pairs])
    # combine all the per-pair lists, preserving order
    action_list = [action for p in prompt_pairs for action in p.action_list]
    multiplier_list = [mult for p in prompt_pairs for mult in p.multiplier_list]
    return EncodedPromptPair(
        target_class=target_class,
        target_class_with_neutral=target_class_with_neutral,
        positive_target=positive_target,
        positive_target_with_neutral=positive_target_with_neutral,
        negative_target=negative_target,
        negative_target_with_neutral=negative_target_with_neutral,
        neutral=neutral,
        empty_prompt=empty_prompt,
        both_targets=both_targets,
        action_list=action_list,
        multiplier_list=multiplier_list,
        weight=weight
    )


def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]:
    """Split a PromptEmbeds into chunks along dim 0 using ``torch.chunk``.

    ``num_parts`` defaults to the size of dim 0 (one row per chunk).
    ``torch.chunk`` may return fewer than ``num_parts`` chunks.
    """
    if num_parts is None:
        # use batch size
        num_parts = concatenated.text_embeds.shape[0]
    text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)

    if concatenated.pooled_embeds is not None:
        pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0)
    else:
        pooled_embeds_splits = [None] * num_parts

    prompt_embeds_list = [
        PromptEmbeds([text, pooled])
        for text, pooled in zip(text_embeds_splits, pooled_embeds_splits)
    ]

    return prompt_embeds_list


def _split_aligned(values: list, chunk_rows: List[int]) -> List[list]:
    """Split ``values`` to follow tensor chunks that hold ``chunk_rows`` rows each.

    When the list has exactly one entry per row, consecutive slices are taken so
    each chunk's entries match that chunk's rows (``torch.chunk`` is contiguous).
    Otherwise the legacy strided split ``values[i::num_chunks]`` is kept.
    """
    num_chunks = len(chunk_rows)
    if sum(chunk_rows) != len(values):
        return [values[i::num_chunks] for i in range(num_chunks)]
    splits = []
    start = 0
    for rows in chunk_rows:
        splits.append(values[start:start + rows])
        start += rows
    return splits


def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]:
    """Inverse of ``concat_prompt_pairs``: split every field into ``num_embeds`` chunks.

    Action/multiplier lists are split so they stay aligned with the rows that
    ``torch.chunk`` places in each chunk. The original weight is copied to every split.
    """
    target_class_splits = split_prompt_embeds(concatenated.target_class, num_embeds)
    target_class_with_neutral_splits = split_prompt_embeds(concatenated.target_class_with_neutral, num_embeds)
    positive_target_splits = split_prompt_embeds(concatenated.positive_target, num_embeds)
    positive_target_with_neutral_splits = split_prompt_embeds(concatenated.positive_target_with_neutral, num_embeds)
    negative_target_splits = split_prompt_embeds(concatenated.negative_target, num_embeds)
    negative_target_with_neutral_splits = split_prompt_embeds(concatenated.negative_target_with_neutral, num_embeds)
    neutral_splits = split_prompt_embeds(concatenated.neutral, num_embeds)
    empty_prompt_splits = split_prompt_embeds(concatenated.empty_prompt, num_embeds)
    both_targets_splits = split_prompt_embeds(concatenated.both_targets, num_embeds)

    # Row count per chunk; torch.chunk can produce uneven (or fewer) chunks.
    chunk_rows = [split.text_embeds.shape[0] for split in target_class_splits]
    action_list_splits = _split_aligned(concatenated.action_list, chunk_rows)
    multiplier_list_splits = _split_aligned(concatenated.multiplier_list, chunk_rows)

    prompt_pairs = []
    for i in range(len(target_class_splits)):
        prompt_pair = EncodedPromptPair(
            target_class=target_class_splits[i],
            target_class_with_neutral=target_class_with_neutral_splits[i],
            positive_target=positive_target_splits[i],
            positive_target_with_neutral=positive_target_with_neutral_splits[i],
            negative_target=negative_target_splits[i],
            negative_target_with_neutral=negative_target_with_neutral_splits[i],
            neutral=neutral_splits[i],
            empty_prompt=empty_prompt_splits[i],
            both_targets=both_targets_splits[i],
            action_list=action_list_splits[i],
            multiplier_list=multiplier_list_splits[i],
            weight=concatenated.weight
        )
        prompt_pairs.append(prompt_pair)

    return prompt_pairs


class PromptEmbedsCache:
    """Mapping from prompt text to its PromptEmbeds; missing keys read as None."""

    prompts: dict[str, PromptEmbeds]

    def __init__(self) -> None:
        # Per-instance storage. A class-level ``{}`` would be shared by every
        # cache ever created, leaking prompts between unrelated caches.
        self.prompts = {}

    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
        return self.prompts.get(__name)


class EncodedAnchor:
    """Encoded anchor prompt plus its negative prompt and multiplier(s)."""

    def __init__(
            self,
            prompt,
            neg_prompt,
            multiplier=1.0,
            multiplier_list=None
    ):
        self.prompt = prompt
        self.neg_prompt = neg_prompt
        self.multiplier = multiplier

        if multiplier_list is not None:
            self.multiplier_list: list[float] = multiplier_list
        else:
            self.multiplier_list: list[float] = [multiplier]

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on both prompt embeddings; returns ``self``."""
        self.prompt = self.prompt.to(*args, **kwargs)
        self.neg_prompt = self.neg_prompt.to(*args, **kwargs)
        return self


def concat_anchors(anchors: list[EncodedAnchor]):
    """Concatenate anchors on dim 0, joining every anchor's full multiplier list."""
    prompt = concat_prompt_embeds([a.prompt for a in anchors])
    neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors])
    return EncodedAnchor(
        prompt=prompt,
        neg_prompt=neg_prompt,
        # Flatten multiplier_list (not .multiplier) so anchors that already
        # carry several multipliers keep all of them.
        multiplier_list=[mult for a in anchors for mult in a.multiplier_list]
    )


def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]:
    """Split a concatenated anchor into up to ``num_anchors`` anchors via ``torch.chunk``.

    Each split's ``multiplier`` (kept as a list for backward compatibility) and
    ``multiplier_list`` both hold that chunk's multipliers.
    """
    prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors)
    neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors)
    multiplier_list_splits = torch.chunk(torch.tensor(concatenated.multiplier_list), num_anchors)

    anchors = []
    for prompt, neg_prompt, multiplier in zip(prompt_splits, neg_prompt_splits, multiplier_list_splits):
        chunk_multipliers = multiplier.tolist()
        anchor = EncodedAnchor(
            prompt=prompt,
            neg_prompt=neg_prompt,
            multiplier=chunk_multipliers,
            # flat list, instead of the default [chunk_multipliers] nesting
            multiplier_list=list(chunk_multipliers)
        )
        anchors.append(anchor)

    return anchors


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


def _load_prompt_tensor_file(cache: PromptEmbedsCache, prompt_tensor_file: str) -> None:
    """Populate ``cache`` from a safetensors file with ``te:<prompt>`` / ``pe:<prompt>`` keys.

    Loaded embeddings are stored on CPU as float32.
    """
    print(f"Loading prompt tensors from {prompt_tensor_file}")
    prompt_tensors = load_file(prompt_tensor_file, device='cpu')
    # add them to the cache
    for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
        if not prompt_txt.startswith("te:"):
            continue
        prompt = prompt_txt[3:]
        # pooled embeds are optional and stored under a parallel "pe:" key
        pooled_embeds = prompt_tensors.get(f"pe:{prompt}")
        prompt_embeds = PromptEmbeds([prompt_tensor, pooled_embeds])
        cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)


def _save_prompt_tensor_file(cache: PromptEmbedsCache, prompt_tensor_file: str) -> None:
    """Write every cached prompt to ``prompt_tensor_file`` as fp16 CPU tensors."""
    # should we shard? It can get large
    print(f"Saving prompt tensors to {prompt_tensor_file}")
    fp16 = get_torch_dtype('fp16')
    state_dict = {}
    for prompt_txt, prompt_embeds in cache.prompts.items():
        state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=fp16)
        if prompt_embeds.pooled_embeds is not None:
            state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=fp16)
    save_file(state_dict, prompt_tensor_file)


@torch.no_grad()
def encode_prompts_to_cache(
        prompt_list: list[str],
        sd: "StableDiffusion",
        cache: Optional[PromptEmbedsCache] = None,
        prompt_tensor_file: Optional[str] = None,
) -> PromptEmbedsCache:
    """Ensure the empty prompt and every prompt in ``prompt_list`` are in a cache.

    If ``prompt_tensor_file`` exists, previously saved embeddings are loaded
    first. Any prompt still missing is encoded with ``sd.encode_prompt`` and
    stored on CPU as float16. When anything new was encoded and a file path was
    given, the whole cache is written back to that file.

    Returns the (possibly newly created) cache.
    """
    # TODO: add support for larger prompts
    if cache is None:
        cache = PromptEmbedsCache()

    if prompt_tensor_file is not None and os.path.exists(prompt_tensor_file):
        _load_prompt_tensor_file(cache, prompt_tensor_file)

    # Encode only what is missing (deduplicated, order preserved). Previously a
    # tensor file lacking newly added prompts left them as None in the cache.
    to_encode = [p for p in dict.fromkeys(["", *prompt_list]) if cache[p] is None]

    if to_encode:
        print("Prompt tensors not found. Encoding prompts..")
        for p in tqdm(to_encode, desc="Encoding prompts", leave=False):
            # the empty prompt is now stored on CPU/fp16 like every other prompt
            cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16)

        if prompt_tensor_file:
            _save_prompt_tensor_file(cache, prompt_tensor_file)

    return cache


if TYPE_CHECKING:
    from toolkit.config_modules import SliderTargetConfig


@torch.no_grad()
def build_prompt_pair_batch_from_cache(
        cache: PromptEmbedsCache,
        target: 'SliderTargetConfig',
        neutral: Optional[str] = '',
) -> list[EncodedPromptPair]:
    """Build the slider prompt pairs for one target from cached embeddings.

    If ``target.positive`` is blank, only the "erase standard" and "enhance
    inverted" pairs are built. If ``target.negative`` is blank, only the
    "enhance standard" and "erase inverted" pairs are built. Otherwise all four
    are built, in the order: erase standard, enhance standard, erase inverted,
    enhance inverted. Inverted pairs use ``-target.multiplier``.

    Cache misses yield None fields; the cache is expected to have been filled
    by ``encode_prompts_to_cache`` with the matching prompt strings.
    """
    erase_negative = len(target.positive.strip()) == 0
    enhance_positive = len(target.negative.strip()) == 0

    both = not erase_negative and not enhance_positive

    def make_pair(positive, negative, action, multiplier) -> EncodedPromptPair:
        # Builds one pair; ``positive``/``negative`` may be swapped by the caller.
        return EncodedPromptPair(
            target_class=cache[target.target_class],
            target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
            positive_target=cache[f"{positive}"],
            positive_target_with_neutral=cache[f"{positive} {neutral}"],
            negative_target=cache[f"{negative}"],
            negative_target_with_neutral=cache[f"{negative} {neutral}"],
            neutral=cache[neutral],
            action=action,
            multiplier=multiplier,
            # both_targets is always "<positive> <negative>", never swapped
            both_targets=cache[f"{target.positive} {target.negative}"],
            empty_prompt=cache[""],
            weight=target.weight
        )

    positive, negative = target.positive, target.negative
    prompt_pair_batch = []

    if both or erase_negative:
        # erase standard
        prompt_pair_batch.append(
            make_pair(positive, negative, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, target.multiplier)
        )
    if both or enhance_positive:
        # enhance standard, swap pos neg
        prompt_pair_batch.append(
            make_pair(negative, positive, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, target.multiplier)
        )
    if both or enhance_positive:
        # erase inverted (positive, swapped)
        prompt_pair_batch.append(
            make_pair(negative, positive, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, target.multiplier * -1.0)
        )
    if both or erase_negative:
        # enhance inverted (negative)
        prompt_pair_batch.append(
            make_pair(positive, negative, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, target.multiplier * -1.0)
        )

    return prompt_pair_batch