Files
ai-toolkit/toolkit/prompt_utils.py
2023-08-19 07:57:30 -06:00

423 lines
17 KiB
Python

import os
from typing import Optional, TYPE_CHECKING, List
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
import itertools
if TYPE_CHECKING:
from toolkit.config_modules import SliderTargetConfig
class ACTION_TYPES_SLIDER:
    """Integer action codes stored in ``EncodedPromptPair.action`` / ``action_list``."""

    ERASE_NEGATIVE, ENHANCE_NEGATIVE = 0, 1
class EncodedPromptPair:
    """Encoded prompt embeddings plus slider settings for slider training.

    The nine embedding attributes are annotated as PromptEmbeds. ``action`` and
    ``multiplier`` are scalar settings; ``action_list`` and ``multiplier_list``
    fall back to one-element lists built from those scalars when not given,
    and are concatenated across pairs by ``concat_prompt_pairs``.
    """

    def __init__(
            self,
            target_class,
            target_class_with_neutral,
            positive_target,
            positive_target_with_neutral,
            negative_target,
            negative_target_with_neutral,
            neutral,
            empty_prompt,
            both_targets,
            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
            action_list=None,
            multiplier=1.0,
            multiplier_list=None,
            weight=1.0
    ):
        # --- embeddings ---
        self.target_class: PromptEmbeds = target_class
        self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
        self.positive_target: PromptEmbeds = positive_target
        self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral
        self.negative_target: PromptEmbeds = negative_target
        self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral
        self.neutral: PromptEmbeds = neutral
        self.empty_prompt: PromptEmbeds = empty_prompt
        self.both_targets: PromptEmbeds = both_targets
        # --- slider settings (an explicit list wins over the scalar fallback) ---
        self.multiplier: float = multiplier
        self.multiplier_list: list[float] = [multiplier] if multiplier_list is None else multiplier_list
        self.action: int = action
        self.action_list: list[int] = [action] if action_list is None else action_list
        self.weight: float = weight

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every embedding, mimicking ``torch.Tensor.to``.

        Each result replaces its attribute in place and ``self`` is returned so
        calls can be chained. The slider settings are left untouched.
        """
        for attr_name in (
                "target_class",
                "target_class_with_neutral",
                "positive_target",
                "positive_target_with_neutral",
                "negative_target",
                "negative_target_with_neutral",
                "neutral",
                "empty_prompt",
                "both_targets",
        ):
            setattr(self, attr_name, getattr(self, attr_name).to(*args, **kwargs))
        return self
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
    """Concatenate a list of PromptEmbeds along dim 0 into a single PromptEmbeds.

    Text embeddings are always concatenated. Pooled embeddings are concatenated
    only when the FIRST element has them; otherwise the result's pooled part is
    None.
    NOTE(review): the other elements' pooled embeddings are not checked, so a
    mix of None and tensors either fails inside torch.cat or is silently dropped.
    """
    stacked_text = torch.cat([item.text_embeds for item in prompt_embeds], dim=0)
    first_has_pooled = prompt_embeds[0].pooled_embeds is not None
    stacked_pooled = (
        torch.cat([item.pooled_embeds for item in prompt_embeds], dim=0)
        if first_has_pooled
        else None
    )
    return PromptEmbeds([stacked_text, stacked_pooled])
def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
    """Merge several EncodedPromptPair objects into one batched pair.

    Every embedding field is concatenated along dim 0 with
    ``concat_prompt_embeds``. The pairs' ``action_list`` and ``multiplier_list``
    entries are joined in the same order.

    NOTE(review): only the FIRST pair's ``weight`` is kept; differing weights in
    later pairs are ignored without warning.

    Raises:
        IndexError: if ``prompt_pairs`` is empty.
    """
    # Read first so that an empty input fails here with IndexError.
    weight = prompt_pairs[0].weight

    def _cat(field_name):
        # Concatenate one embedding field across all pairs.
        return concat_prompt_embeds([getattr(pair, field_name) for pair in prompt_pairs])

    return EncodedPromptPair(
        target_class=_cat("target_class"),
        target_class_with_neutral=_cat("target_class_with_neutral"),
        positive_target=_cat("positive_target"),
        positive_target_with_neutral=_cat("positive_target_with_neutral"),
        negative_target=_cat("negative_target"),
        negative_target_with_neutral=_cat("negative_target_with_neutral"),
        neutral=_cat("neutral"),
        empty_prompt=_cat("empty_prompt"),
        both_targets=_cat("both_targets"),
        action_list=[action for pair in prompt_pairs for action in pair.action_list],
        multiplier_list=[mult for pair in prompt_pairs for mult in pair.multiplier_list],
        weight=weight,
    )
def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]:
    """Split a batched PromptEmbeds into smaller PromptEmbeds along dim 0.

    Args:
        concatenated: embeddings whose text (and optional pooled) tensors are
            split along dim 0.
        num_parts: number of chunks requested from ``torch.chunk``. It defaults
            to ``text_embeds.shape[0]``, giving one PromptEmbeds per row.

    Returns:
        A list of PromptEmbeds. ``torch.chunk`` may return fewer chunks than
        requested when the size does not divide evenly.
    """
    parts = concatenated.text_embeds.shape[0] if num_parts is None else num_parts
    text_chunks = torch.chunk(concatenated.text_embeds, parts, dim=0)
    if concatenated.pooled_embeds is None:
        pooled_chunks = [None] * parts
    else:
        pooled_chunks = torch.chunk(concatenated.pooled_embeds, parts, dim=0)
    return [PromptEmbeds(list(chunk_pair)) for chunk_pair in zip(text_chunks, pooled_chunks)]
def _chunk_list_like_torch(items, num_chunks):
    """Cut ``items`` into ``num_chunks`` contiguous slices sized like ``torch.chunk``.

    Each slice holds ceil(len(items) / num_chunks) items, and the last
    non-empty slice may be shorter. Trailing slices are empty when there are
    fewer items than chunks. Because the slices are contiguous, they line up
    with the contiguous row blocks ``torch.chunk`` produces.
    """
    if num_chunks <= 0:
        return []
    size = (len(items) + num_chunks - 1) // num_chunks
    return [items[i * size:(i + 1) * size] for i in range(num_chunks)]


def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]:
    """Split a batched EncodedPromptPair (e.g. from concat_prompt_pairs) into smaller pairs.

    Every embedding field is split with ``split_prompt_embeds``. ``action_list``
    and ``multiplier_list`` are split into the same number of CONTIGUOUS slices.
    The previous strided slicing (``[i::n]``) gave a split entries from other
    splits' rows whenever a split contained more than one entry, because both
    concatenation and ``torch.chunk`` are contiguous. Every new pair keeps the
    original ``weight``.

    Args:
        concatenated: the batched pair to split.
        num_embeds: number of chunks to request; defaults to the batch size.

    Returns:
        One EncodedPromptPair per chunk of ``target_class``.
    """
    embed_fields = (
        "target_class",
        "target_class_with_neutral",
        "positive_target",
        "positive_target_with_neutral",
        "negative_target",
        "negative_target_with_neutral",
        "neutral",
        "empty_prompt",
        "both_targets",
    )
    splits = {
        field_name: split_prompt_embeds(getattr(concatenated, field_name), num_embeds)
        for field_name in embed_fields
    }
    num_splits = len(splits["target_class"])
    action_chunks = _chunk_list_like_torch(concatenated.action_list, num_splits)
    multiplier_chunks = _chunk_list_like_torch(concatenated.multiplier_list, num_splits)

    prompt_pairs = []
    for i in range(num_splits):
        prompt_pairs.append(
            EncodedPromptPair(
                **{field_name: splits[field_name][i] for field_name in embed_fields},
                action_list=action_chunks[i],
                multiplier_list=multiplier_chunks[i],
                weight=concatenated.weight,
            )
        )
    return prompt_pairs
class PromptEmbedsCache:
    """Dictionary-like store mapping prompt text to its encoded embeddings.

    Looking up a prompt that is not cached returns ``None`` instead of raising
    ``KeyError``. Each instance owns its own storage.
    """

    def __init__(self) -> None:
        # Created per instance. A class-level ``prompts = {}`` would be a single
        # dict shared by every PromptEmbedsCache ever constructed.
        self.prompts: dict[str, PromptEmbeds] = {}

    def __setitem__(self, __name: str, __value: 'PromptEmbeds') -> None:
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional['PromptEmbeds']:
        # Missing prompts yield None.
        return self.prompts.get(__name)

    def __contains__(self, __name: str) -> bool:
        # Without this, ``key in cache`` falls back to calling __getitem__(0),
        # __getitem__(1), ... which never raises IndexError and so never ends.
        return __name in self.prompts
class EncodedAnchor:
    """A pair of prompt embeddings (``prompt`` / ``neg_prompt``) with slider multiplier(s).

    ``multiplier_list`` falls back to ``[multiplier]`` when no list is given.
    """

    def __init__(
            self,
            prompt,
            neg_prompt,
            multiplier=1.0,
            multiplier_list=None
    ):
        self.prompt = prompt
        self.neg_prompt = neg_prompt
        self.multiplier = multiplier
        self.multiplier_list: list[float] = [multiplier] if multiplier_list is None else multiplier_list

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on both embeddings in place (like ``torch.Tensor.to``) and return self."""
        for attr_name in ("prompt", "neg_prompt"):
            setattr(self, attr_name, getattr(self, attr_name).to(*args, **kwargs))
        return self
def concat_anchors(anchors: list[EncodedAnchor]):
    """Batch several EncodedAnchor objects into one by concatenating along dim 0.

    ``prompt`` and ``neg_prompt`` are concatenated with ``concat_prompt_embeds``.
    The anchors' ``multiplier_list`` entries are joined in order, the same way
    ``concat_prompt_pairs`` combines ``multiplier_list``.
    """
    prompt = concat_prompt_embeds([a.prompt for a in anchors])
    neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors])
    # Use each anchor's full multiplier_list, not just its scalar ``multiplier``.
    # For a plain anchor the two agree ([multiplier]). For an anchor that already
    # carries several multipliers (e.g. an earlier concat_anchors result, whose
    # scalar is the default 1.0), the scalar would lose them.
    multiplier_list = [mult for a in anchors for mult in a.multiplier_list]
    return EncodedAnchor(
        prompt=prompt,
        neg_prompt=neg_prompt,
        multiplier_list=multiplier_list
    )
def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]:
    """Split a batched EncodedAnchor into up to ``num_anchors`` smaller anchors.

    Both prompt embeddings are split with ``split_prompt_embeds``
    (``torch.chunk``). ``multiplier_list`` is cut into contiguous slices using
    the same chunk size. The pieces are zipped together, so the result is as
    long as the shortest split.
    """
    prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors)
    neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors)

    # Slice the Python list directly rather than via torch.tensor(...).tolist().
    # That round trip stored Python floats as float32, so e.g. 0.1 came back as
    # 0.10000000149011612. The chunk size mirrors torch.chunk:
    # ceil(len / num_anchors), with the last slice possibly shorter.
    multipliers = list(concatenated.multiplier_list)
    if multipliers:
        chunk_size = -(-len(multipliers) // num_anchors)
        multiplier_splits = [
            multipliers[start:start + chunk_size]
            for start in range(0, len(multipliers), chunk_size)
        ]
    else:
        # NOTE(review): assumed to match torch.chunk on an empty tensor
        # (one empty chunk per requested anchor) -- confirm if this case matters.
        multiplier_splits = [[] for _ in range(num_anchors)]

    anchors = []
    for prompt, neg_prompt, multiplier in zip(prompt_splits, neg_prompt_splits, multiplier_splits):
        # NOTE(review): the slice is passed as ``multiplier`` (unchanged from the
        # original), so each new anchor's ``multiplier`` is a list and its
        # ``multiplier_list`` wraps that list. Confirm this is what callers expect.
        anchors.append(
            EncodedAnchor(
                prompt=prompt,
                neg_prompt=neg_prompt,
                multiplier=multiplier
            )
        )
    return anchors
def get_permutations(s):
    """Return every ordering of the comma-separated phrases in ``s``.

    Phrases are whitespace-stripped and empty ones are dropped. Each ordering
    is re-joined with ", ". An input without any phrase yields ``[""]``, the
    single ordering of zero phrases. The result grows factorially with the
    number of phrases.
    """
    phrases = list(filter(None, map(str.strip, s.split(','))))
    return list(map(', '.join, itertools.permutations(phrases)))
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
    """Expand one slider target into a config per phrase ordering.

    Every ordering of ``target.positive``'s comma-separated phrases is combined
    with every ordering of ``target.negative``'s (cartesian product, positive
    varying slowest). ``target_class``, ``multiplier`` and ``weight`` are
    copied unchanged.
    """
    # Imported here rather than at module level; at the top it is only imported for type checking.
    from toolkit.config_modules import SliderTargetConfig

    combos = itertools.product(
        get_permutations(target.positive),
        get_permutations(target.negative),
    )
    return [
        SliderTargetConfig(
            target_class=target.target_class,
            positive=positive_text,
            negative=negative_text,
            multiplier=target.multiplier,
            weight=target.weight
        )
        for positive_text, negative_text in combos
    ]
if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


def _load_prompt_tensors_into_cache(prompt_tensor_file: str, cache: 'PromptEmbedsCache') -> None:
    """Fill ``cache`` from a safetensors file of previously encoded prompts.

    Text embeddings are read from ``"te:<prompt>"`` keys and optional pooled
    embeddings from the matching ``"pe:<prompt>"`` keys. Loaded entries are
    stored on the CPU as float32.
    """
    print(f"Loading prompt tensors from {prompt_tensor_file}")
    prompt_tensors = load_file(prompt_tensor_file, device='cpu')
    for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
        if not prompt_txt.startswith("te:"):
            # "pe:" entries are picked up together with their matching "te:" entry.
            continue
        prompt = prompt_txt[3:]
        pooled_embeds = prompt_tensors.get(f"pe:{prompt}")
        prompt_embeds = PromptEmbeds([prompt_tensor, pooled_embeds])
        cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)


def _save_prompt_tensors_from_cache(prompt_tensor_file: str, cache: 'PromptEmbedsCache') -> None:
    """Write every cached prompt to ``prompt_tensor_file`` on the CPU, cast to ``get_torch_dtype('fp16')``."""
    print(f"Saving prompt tensors to {prompt_tensor_file}")
    state_dict = {}
    for prompt_txt, prompt_embeds in cache.prompts.items():
        state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to(
            "cpu", dtype=get_torch_dtype('fp16')
        )
        if prompt_embeds.pooled_embeds is not None:
            state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to(
                "cpu",
                dtype=get_torch_dtype('fp16')
            )
    save_file(state_dict, prompt_tensor_file)


@torch.no_grad()
def encode_prompts_to_cache(
        prompt_list: list[str],
        sd: "StableDiffusion",
        cache: Optional[PromptEmbedsCache] = None,
        prompt_tensor_file: Optional[str] = None,
) -> PromptEmbedsCache:
    """Encode ``prompt_list`` (plus the empty prompt) with ``sd`` into a PromptEmbedsCache.

    Args:
        prompt_list: prompts to make available in the cache. Prompts already
            cached are not re-encoded.
        sd: model whose ``encode_prompt`` produces the embeddings.
        cache: cache to fill; a new PromptEmbedsCache is created when None.
        prompt_tensor_file: optional safetensors path. If the file exists, its
            entries are loaded first. Afterwards the whole cache is written back
            to it.

    Returns:
        The filled cache, which always contains the empty prompt ``""``. Freshly
        encoded prompts are stored on the CPU as float16. Prompts loaded from
        ``prompt_tensor_file`` are stored on the CPU as float32.
    """
    # TODO: add support for larger prompts
    if cache is None:
        cache = PromptEmbedsCache()

    if prompt_tensor_file is not None and os.path.exists(prompt_tensor_file):
        _load_prompt_tensors_into_cache(prompt_tensor_file, cache)

    if len(cache.prompts) == 0:
        print("Prompt tensors not found. Encoding prompts..")

    # The empty prompt is always required (build_prompt_pair_batch_from_cache
    # reads cache[""]). Encode it only when it is missing, and move it to CPU
    # float16 like every other freshly encoded prompt. It used to be re-encoded
    # on every call and left wherever sd.encode_prompt put it.
    empty_prompt = ""
    if cache[empty_prompt] is None:
        cache[empty_prompt] = sd.encode_prompt(empty_prompt).to(device="cpu", dtype=torch.float16)

    for p in tqdm(prompt_list, desc="Encoding prompts", leave=False):
        if cache[p] is None:
            cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16)

    # should we shard? It can get large
    if prompt_tensor_file:
        _save_prompt_tensors_from_cache(prompt_tensor_file, cache)
    return cache
@torch.no_grad()
def build_prompt_pair_batch_from_cache(
        cache: PromptEmbedsCache,
        target: 'SliderTargetConfig',
        neutral: Optional[str] = '',
) -> list[EncodedPromptPair]:
    """Assemble the slider training pairs for ``target`` from pre-encoded prompts.

    All embeddings are looked up in ``cache``. Which pairs are produced depends
    on which side of the slider is blank (whitespace only):

    * ``target.positive`` blank, negative not: the two "erase negative" pairs
    * ``target.negative`` blank, positive not: the two "enhance positive" pairs
      (positive/negative swapped)
    * both non-blank, or both blank: all four pairs

    Pairs come out in a fixed order, skipping those that do not apply: erase
    standard, enhance standard (swapped), erase inverted (swapped, multiplier
    negated), enhance inverted (multiplier negated). ``both_targets`` is always
    the unswapped "<positive> <negative>" prompt.
    """
    pos_text = target.positive
    neg_text = target.negative
    erase_negative = not pos_text.strip()
    enhance_positive = not neg_text.strip()
    use_all = not (erase_negative or enhance_positive)

    scale = target.multiplier
    # (include?, swap positive/negative?, action, multiplier) -- order matters.
    plan = (
        # erase standard
        (use_all or erase_negative, False, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, scale),
        # enhance standard, swap pos neg
        (use_all or enhance_positive, True, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, scale),
        # erase inverted
        (use_all or enhance_positive, True, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, scale * -1.0),
        # enhance inverted
        (use_all or erase_negative, False, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, scale * -1.0),
    )

    prompt_pair_batch = []
    for include, swap, action, pair_multiplier in plan:
        if not include:
            continue
        pos, neg = (neg_text, pos_text) if swap else (pos_text, neg_text)
        prompt_pair_batch.append(
            EncodedPromptPair(
                target_class=cache[target.target_class],
                target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
                positive_target=cache[f"{pos}"],
                positive_target_with_neutral=cache[f"{pos} {neutral}"],
                negative_target=cache[f"{neg}"],
                negative_target_with_neutral=cache[f"{neg} {neutral}"],
                neutral=cache[neutral],
                action=action,
                multiplier=pair_multiplier,
                both_targets=cache[f"{pos_text} {neg_text}"],
                empty_prompt=cache[""],
                weight=target.weight
            )
        )
    return prompt_pair_batch