mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
536 lines
20 KiB
Python
import os
|
|
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
|
|
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from tqdm import tqdm
|
|
import random
|
|
|
|
from toolkit.train_tools import get_torch_dtype
|
|
import itertools
|
|
|
|
if TYPE_CHECKING:
|
|
from toolkit.config_modules import SliderTargetConfig
|
|
|
|
|
|
class ACTION_TYPES_SLIDER:
    """Integer codes for the slider action carried by an ``EncodedPromptPair``.

    The codes are stored on ``EncodedPromptPair.action`` / ``action_list``; how they
    are interpreted is decided by the training code that consumes those pairs.
    """

    # action code used for "erase" pairs
    ERASE_NEGATIVE: int = 0
    # action code used for "enhance" pairs
    ENHANCE_NEGATIVE: int = 1
|
|
|
|
|
|
class PromptEmbeds:
    """Holder for an encoded prompt: text embeddings plus optional pooled embeddings.

    Build it from a bare tensor (``pooled_embeds`` is then ``None``; the original
    code labels this the SD 1.x / 2.x case). Alternatively, build it from a
    list/tuple whose first two items are ``(text_embeds, pooled_embeds)``; the
    original code labels this the SDXL case, and the pooled item may itself be
    ``None``.

    ``to`` and ``detach`` modify this object in place and return it, so calls chain.
    """

    text_embeds: torch.Tensor
    pooled_embeds: Union[torch.Tensor, None]

    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
        if not isinstance(args, (list, tuple)):
            # sdv1.x, sdv2.x: a single tensor, no pooled output
            self.text_embeds = args
            self.pooled_embeds = None
            return
        # xl: (text_embeds, pooled_embeds)
        self.text_embeds = args[0]
        self.pooled_embeds = args[1]

    def _map_tensors(self, transform):
        # Replace each stored tensor with transform(tensor), skipping an absent pooled tensor.
        self.text_embeds = transform(self.text_embeds)
        if self.pooled_embeds is not None:
            self.pooled_embeds = transform(self.pooled_embeds)
        return self

    def to(self, *args, **kwargs):
        """Apply ``Tensor.to(*args, **kwargs)`` to the stored tensors in place; returns ``self``."""
        return self._map_tensors(lambda tensor: tensor.to(*args, **kwargs))

    def detach(self):
        """Detach the stored tensors from the autograd graph in place; returns ``self``."""
        return self._map_tensors(lambda tensor: tensor.detach())

    def clone(self):
        """Return a new ``PromptEmbeds`` holding cloned copies of the stored tensors."""
        pooled = self.pooled_embeds
        if pooled is None:
            return PromptEmbeds(self.text_embeds.clone())
        return PromptEmbeds([self.text_embeds.clone(), pooled.clone()])
|
|
|
|
|
|
class EncodedPromptPair:
    """Cached prompt embeddings and action/multiplier metadata for one slider training pair.

    The embedding attributes are stored exactly as passed in; they are meant to be
    ``PromptEmbeds`` objects, and nothing here checks for ``None``.
    ``multiplier_list`` and ``action_list`` default to one-element lists built from
    the scalar ``multiplier`` and ``action`` when they are not provided.
    """

    # PromptEmbeds-valued attributes, in the order `to` converts them.
    _EMBED_FIELDS = (
        "target_class",
        "target_class_with_neutral",
        "positive_target",
        "positive_target_with_neutral",
        "negative_target",
        "negative_target_with_neutral",
        "neutral",
        "empty_prompt",
        "both_targets",
    )

    def __init__(
            self,
            target_class,
            target_class_with_neutral,
            positive_target,
            positive_target_with_neutral,
            negative_target,
            negative_target_with_neutral,
            neutral,
            empty_prompt,
            both_targets,
            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
            action_list=None,
            multiplier=1.0,
            multiplier_list=None,
            weight=1.0,
            target: 'SliderTargetConfig' = None,
    ):
        # embeddings
        self.target_class: PromptEmbeds = target_class
        self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
        self.positive_target: PromptEmbeds = positive_target
        self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral
        self.negative_target: PromptEmbeds = negative_target
        self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral
        self.neutral: PromptEmbeds = neutral
        self.empty_prompt: PromptEmbeds = empty_prompt
        self.both_targets: PromptEmbeds = both_targets

        # scalar metadata
        self.multiplier: float = multiplier
        self.target: 'SliderTargetConfig' = target

        # per-entry metadata; fall back to a single entry built from the scalar
        self.multiplier_list: list[float] = [multiplier] if multiplier_list is None else multiplier_list
        self.action: int = action
        self.action_list: list[int] = [action] if action_list is None else action_list
        self.weight: float = weight

    # simulate torch to for tensors
    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every embedding attribute and store the result; returns ``self``."""
        for field_name in self._EMBED_FIELDS:
            converted = getattr(self, field_name).to(*args, **kwargs)
            setattr(self, field_name, converted)
        return self
|
|
|
|
|
|
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
    """Concatenate several ``PromptEmbeds`` along dim 0 into a single ``PromptEmbeds``.

    Whether pooled embeddings are concatenated is decided by the first item alone.
    If its ``pooled_embeds`` is None, the result has no pooled tensor.
    """
    stacked_text = torch.cat([item.text_embeds for item in prompt_embeds], dim=0)

    stacked_pooled = None
    if prompt_embeds[0].pooled_embeds is not None:
        stacked_pooled = torch.cat([item.pooled_embeds for item in prompt_embeds], dim=0)

    return PromptEmbeds([stacked_text, stacked_pooled])
|
|
|
|
|
|
def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
    """Merge several ``EncodedPromptPair`` objects into one batched pair.

    Every embedding field is concatenated along dim 0 (via ``concat_prompt_embeds``)
    in input order. The per-pair ``action_list`` and ``multiplier_list`` entries are
    joined in that same order. ``weight`` and ``target`` are taken from the first
    pair only; the other pairs' values for those two fields are ignored.

    Raises:
        IndexError: if ``prompt_pairs`` is empty.
    """
    first = prompt_pairs[0]

    embed_fields = (
        "target_class",
        "target_class_with_neutral",
        "positive_target",
        "positive_target_with_neutral",
        "negative_target",
        "negative_target_with_neutral",
        "neutral",
        "empty_prompt",
        "both_targets",
    )
    merged_embeds = {
        name: concat_prompt_embeds([getattr(pair, name) for pair in prompt_pairs])
        for name in embed_fields
    }

    # combine all the lists, preserving pair order so entries line up with the rows above
    action_list = list(itertools.chain.from_iterable(pair.action_list for pair in prompt_pairs))
    multiplier_list = list(itertools.chain.from_iterable(pair.multiplier_list for pair in prompt_pairs))

    return EncodedPromptPair(
        **merged_embeds,
        action_list=action_list,
        multiplier_list=multiplier_list,
        weight=first.weight,
        target=first.target,
    )
|
|
|
|
|
|
def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]:
    """Split a batched ``PromptEmbeds`` along dim 0 with ``torch.chunk``.

    ``num_parts`` defaults to the text-embedding batch size, giving one piece per row.
    ``torch.chunk`` may produce fewer pieces than requested. Pieces are paired with
    ``zip``, so the result is as long as the shorter of the text and pooled chunk
    sequences.
    """
    if num_parts is None:
        # use batch size
        num_parts = concatenated.text_embeds.shape[0]

    text_pieces = torch.chunk(concatenated.text_embeds, num_parts, dim=0)

    pooled_source = concatenated.pooled_embeds
    if pooled_source is None:
        pooled_pieces = [None] * num_parts
    else:
        pooled_pieces = torch.chunk(pooled_source, num_parts, dim=0)

    pieces = []
    for text_piece, pooled_piece in zip(text_pieces, pooled_pieces):
        pieces.append(PromptEmbeds([text_piece, pooled_piece]))
    return pieces
|
|
|
|
|
|
def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]:
    """Split a batched ``EncodedPromptPair`` into smaller pairs along the batch dimension.

    Each embedding field is split with ``split_prompt_embeds``, i.e. ``torch.chunk``,
    which yields contiguous row ranges. ``action_list`` and ``multiplier_list`` are
    split into contiguous runs in the same way. This keeps list entry ``j`` with
    row ``j`` and makes the function the inverse of ``concat_prompt_pairs``.
    ``weight`` and ``target`` are copied onto every piece.

    Args:
        concatenated: the batched pair to split.
        num_embeds: number of pieces requested. ``None`` means one piece per row
            of each field (see ``split_prompt_embeds``).

    Returns:
        One ``EncodedPromptPair`` per ``target_class`` chunk.
    """
    embed_fields = (
        "target_class",
        "target_class_with_neutral",
        "positive_target",
        "positive_target_with_neutral",
        "negative_target",
        "negative_target_with_neutral",
        "neutral",
        "empty_prompt",
        "both_targets",
    )
    splits = {
        name: split_prompt_embeds(getattr(concatenated, name), num_embeds)
        for name in embed_fields
    }

    num_pieces = len(splits["target_class"])
    if num_pieces == 0:
        return []

    def _contiguous_runs(values):
        # Partition `values` into `num_pieces` consecutive runs of ceil(len / num_pieces)
        # items; trailing runs may be short or empty. The previous strided slice
        # (values[i::num_pieces]) interleaved entries whenever a piece held more than one,
        # mismatching the contiguous tensor chunks. With <= 1 entry per piece, both agree.
        run = -(-len(values) // num_pieces)
        return [values[k * run:(k + 1) * run] for k in range(num_pieces)]

    action_runs = _contiguous_runs(concatenated.action_list)
    multiplier_runs = _contiguous_runs(concatenated.multiplier_list)

    prompt_pairs = []
    for i in range(num_pieces):
        prompt_pairs.append(
            EncodedPromptPair(
                **{name: splits[name][i] for name in embed_fields},
                action_list=action_runs[i],
                multiplier_list=multiplier_runs[i],
                weight=concatenated.weight,
                target=concatenated.target,
            )
        )
    return prompt_pairs
|
|
|
|
|
|
class PromptEmbedsCache:
    """Dictionary-style store mapping prompt text to its encoded ``PromptEmbeds``.

    Looking up a prompt that is not stored returns ``None`` instead of raising
    ``KeyError``. The underlying dict is exposed as ``prompts``.
    """

    prompts: dict[str, 'PromptEmbeds']

    def __init__(self) -> None:
        # Per-instance storage. A class-level `prompts = {}` default would be a single
        # dict shared by every PromptEmbedsCache, so writes through one cache would
        # show up in all of them (and a "new" cache would never be empty).
        self.prompts = {}

    def __setitem__(self, __name: str, __value: 'PromptEmbeds') -> None:
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional['PromptEmbeds']:
        return self.prompts.get(__name)
|
|
|
|
|
|
class EncodedAnchor:
    """A prompt / negative-prompt embedding pair used as an anchor, with its multiplier(s).

    If ``multiplier_list`` is not given, it defaults to ``[multiplier]``.
    """

    def __init__(
            self,
            prompt,
            neg_prompt,
            multiplier=1.0,
            multiplier_list=None
    ):
        self.prompt = prompt
        self.neg_prompt = neg_prompt
        self.multiplier = multiplier
        self.multiplier_list: list[float] = [multiplier] if multiplier_list is None else multiplier_list

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on ``prompt`` then ``neg_prompt``, storing the results; returns ``self``."""
        for attr in ("prompt", "neg_prompt"):
            setattr(self, attr, getattr(self, attr).to(*args, **kwargs))
        return self
|
|
|
|
|
|
def concat_anchors(anchors: list[EncodedAnchor]):
    """Concatenate anchors along dim 0 into a single ``EncodedAnchor``.

    ``prompt`` and ``neg_prompt`` are concatenated with ``concat_prompt_embeds``.
    The result's ``multiplier_list`` joins every anchor's own ``multiplier_list`` in
    order, mirroring how ``concat_prompt_pairs`` merges its lists. The result's
    scalar ``multiplier`` is left at the ``EncodedAnchor`` default.
    """
    prompt = concat_prompt_embeds([a.prompt for a in anchors])
    neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors])
    # Use multiplier_list rather than the scalar `multiplier`. For an anchor built
    # from a scalar, the two agree ([multiplier]). For an anchor that is itself a
    # concatenation, the scalar is just the default and would drop its per-row values.
    multiplier_list = [m for a in anchors for m in a.multiplier_list]
    return EncodedAnchor(
        prompt=prompt,
        neg_prompt=neg_prompt,
        multiplier_list=multiplier_list
    )
|
|
|
|
|
|
def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]:
    """Split a batched ``EncodedAnchor`` into up to ``num_anchors`` pieces.

    ``prompt`` and ``neg_prompt`` are split with ``split_prompt_embeds``.
    ``multiplier_list`` is converted to a tensor and split with ``torch.chunk``, so
    the values pass through torch's default float dtype. Each piece gets that chunk
    as a plain list in both ``multiplier`` (unchanged from before) and
    ``multiplier_list``. ``multiplier_list`` was previously wrapped as ``[chunk]``, a
    nested list, unlike the flat lists produced by ``split_prompt_pairs``.
    The number of pieces is the shortest of the three split sequences.
    """
    prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors)
    neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors)
    multiplier_list_splits = torch.chunk(torch.tensor(concatenated.multiplier_list), num_anchors)

    anchors = []
    for prompt, neg_prompt, multiplier_chunk in zip(prompt_splits, neg_prompt_splits, multiplier_list_splits):
        multipliers = multiplier_chunk.tolist()
        anchors.append(
            EncodedAnchor(
                prompt=prompt,
                neg_prompt=neg_prompt,
                multiplier=multipliers,
                # separate copy so mutating one attribute does not alter the other
                multiplier_list=list(multipliers),
            )
        )
    return anchors
|
|
|
|
|
|
def get_permutations(s):
    """Return every ordering of the comma-separated phrases in ``s``, re-joined with ``", "``.

    Phrases are whitespace-stripped and empty phrases are dropped. A string with no
    phrases yields ``['']``. The result grows factorially with the number of phrases.
    """
    cleaned = []
    for raw_phrase in s.split(','):
        phrase = raw_phrase.strip()
        if phrase:
            cleaned.append(phrase)
    return list(map(', '.join, itertools.permutations(cleaned)))
|
|
|
|
|
|
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
    """Build ``SliderTargetConfig`` variants with the positive/negative phrase orders permuted.

    Every combination of a positive ordering and a negative ordering (see
    ``get_permutations``) becomes a new config that copies ``target_class``,
    ``multiplier`` and ``weight`` from ``target``. The list is shuffled with
    ``random.shuffle`` and then truncated to ``max_permutations`` entries.
    The full combination list is built before truncation.
    """
    from toolkit.config_modules import SliderTargetConfig

    positive_orders = get_permutations(target.positive)
    negative_orders = get_permutations(target.negative)

    variants = [
        SliderTargetConfig(
            target_class=target.target_class,
            positive=pos_text,
            negative=neg_text,
            multiplier=target.multiplier,
            weight=target.weight
        )
        for pos_text, neg_text in itertools.product(positive_orders, negative_orders)
    ]

    # shuffle first so the retained subset is a random sample
    random.shuffle(variants)

    if len(variants) > max_permutations:
        del variants[max_permutations:]

    return variants
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from toolkit.stable_diffusion_model import StableDiffusion
|
|
|
|
|
|
def _load_prompt_tensor_file(prompt_tensor_file: str, cache: "PromptEmbedsCache") -> None:
    """Populate ``cache`` from a safetensors file in the layout written by ``_save_prompt_tensor_file``.

    Keys are ``"te:<prompt>"`` for text embeddings and, when present,
    ``"pe:<prompt>"`` for pooled embeddings. Loaded embeddings are stored on CPU as float32.
    """
    print(f"Loading prompt tensors from {prompt_tensor_file}")
    prompt_tensors = load_file(prompt_tensor_file, device='cpu')
    # add them to the cache
    for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
        if not prompt_txt.startswith("te:"):
            # pooled ("pe:") entries are picked up together with their text entry
            continue
        prompt = prompt_txt[3:]
        pooled_embeds = prompt_tensors.get(f"pe:{prompt}")
        prompt_embeds = PromptEmbeds([prompt_tensor, pooled_embeds])
        cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)


def _save_prompt_tensor_file(prompt_tensor_file: str, cache: "PromptEmbedsCache") -> None:
    """Write every cached prompt to ``prompt_tensor_file`` as fp16 CPU tensors (``te:``/``pe:`` keys)."""
    # should we shard? It can get large
    print(f"Saving prompt tensors to {prompt_tensor_file}")
    save_dtype = get_torch_dtype('fp16')
    state_dict = {}
    for prompt_txt, prompt_embeds in cache.prompts.items():
        state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=save_dtype)
        if prompt_embeds.pooled_embeds is not None:
            state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=save_dtype)
    save_file(state_dict, prompt_tensor_file)


@torch.no_grad()
def encode_prompts_to_cache(
        prompt_list: list[str],
        sd: "StableDiffusion",
        cache: Optional[PromptEmbedsCache] = None,
        prompt_tensor_file: Optional[str] = None,
) -> PromptEmbedsCache:
    """Ensure every prompt in ``prompt_list`` (plus the empty prompt) has cached embeddings.

    If ``prompt_tensor_file`` exists, its tensors are loaded into the cache first.
    Any prompt still missing is then encoded with ``sd.encode_prompt``. Previously,
    encoding only happened when the cache was completely empty, so prompts absent
    from an existing tensor file were silently left as ``None``.

    Freshly encoded prompts are stored on CPU as float16. The empty prompt is stored
    exactly as ``sd.encode_prompt`` returns it. If anything new was encoded and
    ``prompt_tensor_file`` is set, the whole cache is written back to that file.

    Args:
        prompt_list: prompts that must be present in the returned cache.
        sd: model wrapper providing ``encode_prompt``.
        cache: cache to fill; a new ``PromptEmbedsCache`` is created when ``None``.
        prompt_tensor_file: optional safetensors path used to load/save the cache.

    Returns:
        The (possibly newly created) cache.
    """
    # TODO: add support for larger prompts
    if cache is None:
        cache = PromptEmbedsCache()

    if prompt_tensor_file is not None and os.path.exists(prompt_tensor_file):
        _load_prompt_tensor_file(prompt_tensor_file, cache)

    needs_empty = cache[""] is None
    # de-duplicate while keeping order; the empty prompt is handled separately below
    pending = [p for p in dict.fromkeys(prompt_list) if p != "" and cache[p] is None]

    if not needs_empty and not pending:
        # everything is already cached; nothing to encode or save
        return cache

    if len(cache.prompts) == 0:
        print("Prompt tensors not found. Encoding prompts..")
    else:
        print(f"Encoding {len(pending) + int(needs_empty)} prompt(s) missing from the prompt cache..")

    if needs_empty:
        # encode empty_prompt
        cache[""] = sd.encode_prompt("")

    for p in tqdm(pending, desc="Encoding prompts", leave=False):
        # build the cache
        cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16)

    if prompt_tensor_file:
        _save_prompt_tensor_file(prompt_tensor_file, cache)

    return cache
|
|
|
|
|
|
@torch.no_grad()
def build_prompt_pair_batch_from_cache(
        cache: PromptEmbedsCache,
        target: 'SliderTargetConfig',
        neutral: Optional[str] = '',
) -> list[EncodedPromptPair]:
    """Assemble the slider ``EncodedPromptPair`` objects for ``target`` from cached embeddings.

    A blank ``target.positive`` selects pairs 1 and 4; a blank ``target.negative``
    selects pairs 2 and 3; when neither is blank, all four are produced:
      1. erase standard   -- positive/negative as given, ERASE_NEGATIVE, +multiplier
      2. enhance standard -- positive/negative swapped, ENHANCE_NEGATIVE, +multiplier
      3. erase inverted   -- positive/negative swapped, ERASE_NEGATIVE, -multiplier
      4. enhance inverted -- positive/negative as given, ENHANCE_NEGATIVE, -multiplier
    Embeddings are looked up by prompt text (some keys have `` {neutral}`` appended).
    A key missing from the cache yields ``None`` for that field.
    """
    erase_negative = not target.positive.strip()
    enhance_positive = not target.negative.strip()
    both = not (erase_negative or enhance_positive)

    def make_pair(pos_text, neg_text, action, multiplier):
        # Every pair shares the class, neutral, both-targets and empty-prompt lookups;
        # only which text fills the positive/negative slots, the action and the multiplier vary.
        return EncodedPromptPair(
            target_class=cache[target.target_class],
            target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
            positive_target=cache[f"{pos_text}"],
            positive_target_with_neutral=cache[f"{pos_text} {neutral}"],
            negative_target=cache[f"{neg_text}"],
            negative_target_with_neutral=cache[f"{neg_text} {neutral}"],
            neutral=cache[neutral],
            both_targets=cache[f"{target.positive} {target.negative}"],
            empty_prompt=cache[""],
            action=action,
            multiplier=multiplier,
            weight=target.weight,
            target=target
        )

    forward = target.multiplier
    inverted = target.multiplier * -1.0

    batch = []
    if both or erase_negative:
        # erase standard
        batch.append(make_pair(target.positive, target.negative, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, forward))
    if both or enhance_positive:
        # enhance standard, swap pos neg
        batch.append(make_pair(target.negative, target.positive, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, forward))
        # erase inverted
        batch.append(make_pair(target.negative, target.positive, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, inverted))
    if both or erase_negative:
        # enhance inverted
        batch.append(make_pair(target.positive, target.negative, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, inverted))

    return batch
|
|
|
|
|
|
def build_latent_image_batch_for_prompt_pair(
        pos_latent,
        neg_latent,
        prompt_pair: EncodedPromptPair,
        prompt_chunk_size
):
    """Concatenate latents along dim 0 in the order the slider prompt pairs are built.

    A blank ``prompt_pair.target.positive`` contributes ``pos_latent`` then
    ``neg_latent``. A blank ``target.negative`` contributes the same two. When
    neither is blank, the order is ``pos, pos, neg, neg``.

    ``prompt_pair`` is split with ``split_prompt_pairs(prompt_pair, prompt_chunk_size)``
    only to count its chunks. A generic ``Exception`` is raised unless there are 4
    chunks (neither side blank) or 2 chunks (otherwise).
    """
    erase_negative = not prompt_pair.target.positive.strip()
    enhance_positive = not prompt_pair.target.negative.strip()
    both = not (erase_negative or enhance_positive)

    chunk_count = len(split_prompt_pairs(prompt_pair, prompt_chunk_size))
    expected_chunks = 4 if both else 2
    if chunk_count != expected_chunks:
        raise Exception("Invalid prompt pair chunks")

    latents = []
    if both or erase_negative:
        latents.append(pos_latent)
    if both or enhance_positive:
        latents.extend([pos_latent, neg_latent])
    if both or erase_negative:
        latents.append(neg_latent)

    return torch.cat(latents, dim=0)
|
|
|
|
|
|
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
    """Replace trigger placeholders in ``prompt`` with ``trigger``.

    Every token in ``to_replace_list`` plus the defaults ``"[name]"`` and
    ``"[trigger]"`` is replaced by ``trigger``. With ``trigger=None``, the
    placeholders are removed instead.

    When ``trigger`` is non-blank:
      * if it does not occur in the result and ``add_if_not_present`` is true,
        ``trigger + " "`` is prepended;
      * if it occurs more than once, a warning is printed.

    Args:
        prompt: the prompt text.
        trigger: replacement text, or ``None`` to strip placeholders.
        to_replace_list: extra placeholder tokens (an iterable of strings, or a
            single string). It is never modified.
        add_if_not_present: whether to prepend a missing trigger.

    Returns:
        The rewritten prompt.
    """
    if trigger is None:
        # process as empty string to remove any [trigger] tokens
        trigger = ''
    default_replacements = ["[name]", "[trigger]"]

    # Build a fresh list. The old `to_replace_list += ...` extended the caller's list
    # in place, so it grew on every call.
    if to_replace_list is None:
        candidates = list(default_replacements)
    elif isinstance(to_replace_list, str):
        candidates = [to_replace_list] + default_replacements
    else:
        candidates = list(to_replace_list) + default_replacements

    # De-duplicate with a stable order (caller tokens first). set() order varies between
    # runs under string-hash randomization, which matters when one token contains another.
    output_prompt = prompt
    for to_replace in dict.fromkeys(candidates):
        if not to_replace:
            # an empty token would make str.replace insert `trigger` between every character
            continue
        output_prompt = output_prompt.replace(to_replace, trigger)

    if trigger.strip() != "":
        # see how many times the trigger is in the prompt
        num_instances = output_prompt.count(trigger)

        if num_instances == 0 and add_if_not_present:
            # add it to the beginning of the prompt
            output_prompt = trigger + " " + output_prompt

        if num_instances > 1:
            print(
                f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

    return output_prompt
|