mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
389 lines
16 KiB
Python
389 lines
16 KiB
Python
import os
|
|
from typing import Optional, TYPE_CHECKING, List
|
|
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from tqdm import tqdm
|
|
|
|
from toolkit.stable_diffusion_model import PromptEmbeds
|
|
from toolkit.train_tools import get_torch_dtype
|
|
|
|
|
|
class ACTION_TYPES_SLIDER:
    """Integer tags stored in ``EncodedPromptPair.action`` / ``action_list``.

    ``build_prompt_pair_batch_from_cache`` emits pairs tagged with each value.
    """

    # Tag for the "erase" pairs; also EncodedPromptPair's default action.
    ERASE_NEGATIVE = 0
    # Tag for the "enhance" pairs.
    ENHANCE_NEGATIVE = 1
|
|
|
|
|
|
class EncodedPromptPair:
    """Container for the encoded prompt embeddings that make up one slider pair.

    Every embedding argument is stored unchanged on the attribute of the same
    name. ``multiplier_list`` / ``action_list`` default to one-element lists
    built from the scalar ``multiplier`` / ``action`` when not supplied.
    """

    def __init__(
            self,
            target_class,
            target_class_with_neutral,
            positive_target,
            positive_target_with_neutral,
            negative_target,
            negative_target_with_neutral,
            neutral,
            empty_prompt,
            both_targets,
            action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
            action_list=None,
            multiplier=1.0,
            multiplier_list=None,
            weight=1.0
    ):
        self.target_class: PromptEmbeds = target_class
        self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
        self.positive_target: PromptEmbeds = positive_target
        self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral
        self.negative_target: PromptEmbeds = negative_target
        self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral
        self.neutral: PromptEmbeds = neutral
        self.empty_prompt: PromptEmbeds = empty_prompt
        self.both_targets: PromptEmbeds = both_targets

        self.multiplier: float = multiplier
        # An explicit list wins; otherwise wrap the scalar.
        self.multiplier_list: list[float] = [multiplier] if multiplier_list is None else multiplier_list

        self.action: int = action
        self.action_list: list[int] = [action] if action_list is None else action_list

        self.weight: float = weight

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every embedding field, rebind the results, and return ``self``.

        Mirrors the ``torch.Tensor.to`` calling convention so the pair can be
        moved/cast like a tensor.
        """
        embed_fields = (
            "target_class",
            "target_class_with_neutral",
            "positive_target",
            "positive_target_with_neutral",
            "negative_target",
            "negative_target_with_neutral",
            "neutral",
            "empty_prompt",
            "both_targets",
        )
        for field_name in embed_fields:
            setattr(self, field_name, getattr(self, field_name).to(*args, **kwargs))
        return self
|
|
|
|
|
|
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
    """Concatenate a list of PromptEmbeds along dim 0 into a single PromptEmbeds.

    ``text_embeds`` are always concatenated. ``pooled_embeds`` are concatenated
    only when the first element has them; otherwise the result's pooled
    embeds are ``None``.
    """
    stacked_text = torch.cat([item.text_embeds for item in prompt_embeds], dim=0)

    # The first element decides whether pooled embeddings are present.
    if prompt_embeds[0].pooled_embeds is None:
        stacked_pooled = None
    else:
        stacked_pooled = torch.cat([item.pooled_embeds for item in prompt_embeds], dim=0)

    return PromptEmbeds([stacked_text, stacked_pooled])
|
|
|
|
|
|
def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
    """Concatenate several EncodedPromptPair objects along the batch dimension.

    Each PromptEmbeds field is concatenated with ``concat_prompt_embeds``.
    The per-pair ``action_list`` / ``multiplier_list`` entries are chained in
    the same pair order, so the metadata stays in the same order as the
    concatenated embeddings.

    Args:
        prompt_pairs: non-empty list of pairs to combine.

    Returns:
        A new EncodedPromptPair holding the concatenated embeddings and the
        chained action/multiplier lists.

    Note:
        Only the first pair's ``weight`` is kept; the others are ignored.
    """
    weight = prompt_pairs[0].weight

    embed_fields = (
        "target_class",
        "target_class_with_neutral",
        "positive_target",
        "positive_target_with_neutral",
        "negative_target",
        "negative_target_with_neutral",
        "neutral",
        "empty_prompt",
        "both_targets",
    )
    embeds = {
        field_name: concat_prompt_embeds([getattr(p, field_name) for p in prompt_pairs])
        for field_name in embed_fields
    }

    # Chain the per-pair metadata lists in pair order.
    action_list = [action for p in prompt_pairs for action in p.action_list]
    multiplier_list = [multiplier for p in prompt_pairs for multiplier in p.multiplier_list]

    return EncodedPromptPair(
        **embeds,
        action_list=action_list,
        multiplier_list=multiplier_list,
        weight=weight,
    )
|
|
|
|
|
|
def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]:
    """Split a batched PromptEmbeds into a list of smaller PromptEmbeds.

    Args:
        concatenated: the batched embeddings to split.
        num_parts: number of chunks passed to ``torch.chunk``; defaults to the
            size of ``text_embeds`` along dim 0 (one chunk per row).

    Returns:
        One PromptEmbeds per chunk produced by ``torch.chunk`` on
        ``text_embeds`` (which may be fewer than ``num_parts``). Pooled embeds
        are chunked the same way, or are ``None`` when absent.
    """
    parts = concatenated.text_embeds.shape[0] if num_parts is None else num_parts
    text_chunks = torch.chunk(concatenated.text_embeds, parts, dim=0)

    pooled = concatenated.pooled_embeds
    pooled_chunks = [None] * parts if pooled is None else torch.chunk(pooled, parts, dim=0)

    return [PromptEmbeds([text_chunk, pooled_chunk]) for text_chunk, pooled_chunk in zip(text_chunks, pooled_chunks)]
|
|
|
|
|
|
def _split_per_row_list(values: list, chunk_sizes: List[int]) -> list:
    """Split per-row metadata ``values`` into groups aligned with embedding chunks of ``chunk_sizes`` rows."""
    num_chunks = len(chunk_sizes)
    if len(values) != sum(chunk_sizes):
        # No one-entry-per-row correspondence to preserve; keep the historical
        # strided assignment for this case.
        return [values[i::num_chunks] for i in range(num_chunks)]

    # torch.chunk splits rows contiguously, so the metadata must be split the same way.
    groups = []
    start = 0
    for size in chunk_sizes:
        groups.append(values[start:start + size])
        start += size
    return groups


def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]:
    """Split a batched EncodedPromptPair into a list of smaller pairs.

    Every PromptEmbeds field is split with ``split_prompt_embeds`` (i.e.
    ``torch.chunk`` along dim 0). When ``action_list`` / ``multiplier_list``
    have one entry per batch row (the order ``concat_prompt_pairs`` produces),
    they are split into the same contiguous row groups so each chunk keeps the
    metadata of its own rows. Otherwise the lists are distributed with a stride,
    as before.

    Args:
        concatenated: the batched pair to split.
        num_embeds: number of chunks; defaults to one chunk per batch row.

    Returns:
        One EncodedPromptPair per chunk, each carrying ``concatenated.weight``.
    """
    target_class_splits = split_prompt_embeds(concatenated.target_class, num_embeds)
    target_class_with_neutral_splits = split_prompt_embeds(concatenated.target_class_with_neutral, num_embeds)
    positive_target_splits = split_prompt_embeds(concatenated.positive_target, num_embeds)
    positive_target_with_neutral_splits = split_prompt_embeds(concatenated.positive_target_with_neutral, num_embeds)
    negative_target_splits = split_prompt_embeds(concatenated.negative_target, num_embeds)
    negative_target_with_neutral_splits = split_prompt_embeds(concatenated.negative_target_with_neutral, num_embeds)
    neutral_splits = split_prompt_embeds(concatenated.neutral, num_embeds)
    empty_prompt_splits = split_prompt_embeds(concatenated.empty_prompt, num_embeds)
    both_targets_splits = split_prompt_embeds(concatenated.both_targets, num_embeds)

    # Rows in each chunk, exactly as torch.chunk produced them.
    chunk_sizes = [split.text_embeds.shape[0] for split in target_class_splits]
    action_list_splits = _split_per_row_list(concatenated.action_list, chunk_sizes)
    multiplier_list_splits = _split_per_row_list(concatenated.multiplier_list, chunk_sizes)

    prompt_pairs = []
    for i, target_class_split in enumerate(target_class_splits):
        prompt_pair = EncodedPromptPair(
            target_class=target_class_split,
            target_class_with_neutral=target_class_with_neutral_splits[i],
            positive_target=positive_target_splits[i],
            positive_target_with_neutral=positive_target_with_neutral_splits[i],
            negative_target=negative_target_splits[i],
            negative_target_with_neutral=negative_target_with_neutral_splits[i],
            neutral=neutral_splits[i],
            empty_prompt=empty_prompt_splits[i],
            both_targets=both_targets_splits[i],
            action_list=action_list_splits[i],
            multiplier_list=multiplier_list_splits[i],
            weight=concatenated.weight
        )
        prompt_pairs.append(prompt_pair)

    return prompt_pairs
|
|
|
|
|
|
class PromptEmbedsCache:
    """Mapping from prompt text to its encoded PromptEmbeds.

    Lookups of unknown prompts return ``None`` instead of raising, so callers
    can test ``cache[prompt] is None`` to decide whether to encode it.
    """

    prompts: "dict[str, PromptEmbeds]"

    def __init__(self) -> None:
        # Per-instance storage: a class-level ``{}`` default would be shared by
        # every PromptEmbedsCache, leaking entries between unrelated caches.
        self.prompts = {}

    def __setitem__(self, __name: str, __value: "PromptEmbeds") -> None:
        """Store ``__value`` under the prompt text ``__name``."""
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> "Optional[PromptEmbeds]":
        """Return the cached embeds for ``__name``, or ``None`` if not cached."""
        return self.prompts.get(__name)
|
|
|
|
|
|
class EncodedAnchor:
    """Encoded anchor prompt plus its negative prompt and multiplier(s).

    ``prompt`` / ``neg_prompt`` are expected to expose ``.to(...)`` (in this
    module they are PromptEmbeds, e.g. the output of ``concat_prompt_embeds``).
    ``multiplier_list`` defaults to ``[multiplier]`` when not supplied.
    """

    def __init__(
            self,
            prompt,
            neg_prompt,
            multiplier=1.0,
            multiplier_list=None
    ):
        self.prompt = prompt
        self.neg_prompt = neg_prompt
        self.multiplier = multiplier
        # Explicit list wins; otherwise wrap the scalar multiplier.
        self.multiplier_list: list[float] = [multiplier] if multiplier_list is None else multiplier_list

    def to(self, *args, **kwargs):
        """Move/cast ``prompt`` then ``neg_prompt`` via their ``.to`` and return ``self``."""
        moved_prompt = self.prompt.to(*args, **kwargs)
        self.prompt = moved_prompt
        moved_neg_prompt = self.neg_prompt.to(*args, **kwargs)
        self.neg_prompt = moved_neg_prompt
        return self
|
|
|
|
|
|
def concat_anchors(anchors: list[EncodedAnchor]):
    """Concatenate several EncodedAnchor objects along the batch dimension.

    ``prompt`` and ``neg_prompt`` are concatenated with ``concat_prompt_embeds``.
    The result's ``multiplier_list`` chains every anchor's ``multiplier_list``
    in order. For anchors built from a single scalar multiplier this equals
    ``[a.multiplier for a in anchors]``. Anchors that already carry several
    multipliers keep all of them, consistent with ``concat_prompt_pairs``.

    Args:
        anchors: non-empty list of anchors to combine.

    Returns:
        A new EncodedAnchor holding the concatenated embeddings.
    """
    prompt = concat_prompt_embeds([a.prompt for a in anchors])
    neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors])
    multiplier_list = [multiplier for a in anchors for multiplier in a.multiplier_list]
    return EncodedAnchor(
        prompt=prompt,
        neg_prompt=neg_prompt,
        multiplier_list=multiplier_list
    )
|
|
|
|
|
|
def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]:
    """Split a batched EncodedAnchor into up to ``num_anchors`` smaller anchors.

    ``prompt`` and ``neg_prompt`` are split with ``split_prompt_embeds``.
    ``multiplier_list`` is split with the same ``torch.chunk`` grouping.

    Each resulting anchor gets a flat ``multiplier_list`` holding its chunk's
    multipliers. ``multiplier`` also receives a copy of that list, to keep the
    attribute value this function has always produced.

    Args:
        concatenated: the batched anchor to split.
        num_anchors: number of chunks requested from ``torch.chunk``.

    Returns:
        The list of split anchors (truncated to the shortest of the splits).
    """
    prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors)
    neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors)

    # Chunk indices rather than values. The grouping matches torch.chunk, but the
    # multipliers keep their original Python values instead of being
    # round-tripped through a float32 tensor (e.g. 0.1 -> 0.10000000149...).
    multipliers = list(concatenated.multiplier_list)
    index_chunks = torch.chunk(torch.arange(len(multipliers)), num_anchors)

    anchors = []
    for prompt, neg_prompt, index_chunk in zip(prompt_splits, neg_prompt_splits, index_chunks):
        chunk_multipliers = [multipliers[i] for i in index_chunk.tolist()]
        anchor = EncodedAnchor(
            prompt=prompt,
            neg_prompt=neg_prompt,
            multiplier=list(chunk_multipliers),
            multiplier_list=chunk_multipliers
        )
        anchors.append(anchor)

    return anchors
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from toolkit.stable_diffusion_model import StableDiffusion
|
|
|
|
|
|
def _load_prompt_tensors(cache: "PromptEmbedsCache", prompt_tensor_file: str) -> None:
    """Load ``te:``/``pe:`` entries from a safetensors file into ``cache`` as float32 CPU embeds."""
    print(f"Loading prompt tensors from {prompt_tensor_file}")
    prompt_tensors = load_file(prompt_tensor_file, device='cpu')
    for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
        # Each entry starts at its "te:" key; the matching "pe:" key is optional.
        if not prompt_txt.startswith("te:"):
            continue
        prompt = prompt_txt[3:]
        pooled_embeds = prompt_tensors.get(f"pe:{prompt}")
        prompt_embeds = PromptEmbeds([prompt_tensor, pooled_embeds])
        cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)


def _save_prompt_tensors(cache: "PromptEmbedsCache", prompt_tensor_file: str) -> None:
    """Write every cached prompt to ``prompt_tensor_file`` as fp16 CPU tensors (``te:``/``pe:`` keys)."""
    print(f"Saving prompt tensors to {prompt_tensor_file}")
    save_dtype = get_torch_dtype('fp16')
    state_dict = {}
    for prompt_txt, prompt_embeds in cache.prompts.items():
        state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=save_dtype)
        if prompt_embeds.pooled_embeds is not None:
            state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=save_dtype)
    save_file(state_dict, prompt_tensor_file)


@torch.no_grad()
def encode_prompts_to_cache(
        prompt_list: list[str],
        sd: "StableDiffusion",
        cache: Optional[PromptEmbedsCache] = None,
        prompt_tensor_file: Optional[str] = None,
) -> PromptEmbedsCache:
    """Ensure every prompt in ``prompt_list`` (plus the empty prompt) is encoded in ``cache``.

    If ``prompt_tensor_file`` exists, its entries are loaded first (as float32
    on CPU). Any prompt still missing afterwards is encoded with
    ``sd.encode_prompt`` and stored on CPU as float16. If anything new was
    encoded and ``prompt_tensor_file`` is set, the whole cache is written back
    to that file.

    Args:
        prompt_list: prompts that must be present in the returned cache.
        sd: model used to encode missing prompts.
        cache: cache to fill; a new PromptEmbedsCache is created when ``None``.
        prompt_tensor_file: optional safetensors file to load from / save to.

    Returns:
        The filled cache.
    """
    # TODO: add support for larger prompts
    if cache is None:
        cache = PromptEmbedsCache()

    if prompt_tensor_file is not None and os.path.exists(prompt_tensor_file):
        _load_prompt_tensors(cache, prompt_tensor_file)

    if len(cache.prompts) == 0:
        print("Prompt tensors not found. Encoding prompts..")

    encoded_new = False

    # The empty prompt is always needed (pairs read cache[""]); store it on CPU
    # in fp16 like every other freshly encoded prompt.
    empty_prompt = ""
    if cache[empty_prompt] is None:
        cache[empty_prompt] = sd.encode_prompt(empty_prompt).to(device="cpu", dtype=torch.float16)
        encoded_new = True

    # Encode anything still missing. This also covers prompts added since the
    # tensor file was written, which previously stayed None.
    for p in tqdm(prompt_list, desc="Encoding prompts", leave=False):
        if cache[p] is None:
            cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16)
            encoded_new = True

    # should we shard? It can get large
    if encoded_new and prompt_tensor_file:
        _save_prompt_tensors(cache, prompt_tensor_file)

    return cache
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from toolkit.config_modules import SliderTargetConfig
|
|
|
|
|
|
@torch.no_grad()
def build_prompt_pair_batch_from_cache(
        cache: PromptEmbedsCache,
        target: 'SliderTargetConfig',
        neutral: Optional[str] = '',
) -> list[EncodedPromptPair]:
    """Build the EncodedPromptPair list for one slider target from cached embeddings.

    Up to four pairs are produced, always in this order:
      1. erase standard  (positive/negative as given, ERASE_NEGATIVE, +multiplier)
      2. enhance standard (positive/negative swapped, ENHANCE_NEGATIVE, +multiplier)
      3. erase inverted   (positive/negative swapped, ERASE_NEGATIVE, -multiplier)
      4. enhance inverted (positive/negative as given, ENHANCE_NEGATIVE, -multiplier)

    A blank ``target.positive`` keeps pairs 1 and 4. A blank
    ``target.negative`` keeps pairs 2 and 3. Otherwise (including both
    blank) all four are built. Embeddings come from ``cache`` lookups, which
    yield ``None`` for prompts that were never cached.
    """
    erase_negative = not target.positive.strip()
    enhance_positive = not target.negative.strip()
    both = not (erase_negative or enhance_positive)

    def make_pair(positive, negative, action, multiplier):
        # Shared construction; only the pos/neg order, action and multiplier vary.
        return EncodedPromptPair(
            target_class=cache[target.target_class],
            target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
            positive_target=cache[f"{positive}"],
            positive_target_with_neutral=cache[f"{positive} {neutral}"],
            negative_target=cache[f"{negative}"],
            negative_target_with_neutral=cache[f"{negative} {neutral}"],
            neutral=cache[neutral],
            both_targets=cache[f"{target.positive} {target.negative}"],
            empty_prompt=cache[""],
            action=action,
            multiplier=multiplier,
            weight=target.weight,
        )

    pos_text = target.positive
    neg_text = target.negative
    forward_multiplier = target.multiplier
    inverse_multiplier = target.multiplier * -1.0

    batch = []
    if both or erase_negative:
        # erase standard
        batch.append(make_pair(pos_text, neg_text, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, forward_multiplier))
    if both or enhance_positive:
        # enhance standard (swap pos/neg), then erase inverted (swap pos/neg)
        batch.append(make_pair(neg_text, pos_text, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, forward_multiplier))
        batch.append(make_pair(neg_text, pos_text, ACTION_TYPES_SLIDER.ERASE_NEGATIVE, inverse_multiplier))
    if both or erase_negative:
        # enhance inverted
        batch.append(make_pair(pos_text, neg_text, ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, inverse_multiplier))

    return batch
|