Built base interfaces for a DTO to handle batch infomation transports for the dataloader

This commit is contained in:
Jaret Burkett
2023-08-28 12:43:31 -06:00
parent 71da78c8af
commit e866c75638
7 changed files with 186 additions and 76 deletions

View File

@@ -1,12 +1,11 @@
import os
from typing import Optional, TYPE_CHECKING, List
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import random
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
import itertools
@@ -19,6 +18,27 @@ class ACTION_TYPES_SLIDER:
ENHANCE_NEGATIVE = 1
class PromptEmbeds:
text_embeds: torch.Tensor
pooled_embeds: Union[torch.Tensor, None]
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
if isinstance(args, list) or isinstance(args, tuple):
# xl
self.text_embeds = args[0]
self.pooled_embeds = args[1]
else:
# sdv1.x, sdv2.x
self.text_embeds = args
self.pooled_embeds = None
def to(self, *args, **kwargs):
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
return self
class EncodedPromptPair:
def __init__(
self,
@@ -465,3 +485,37 @@ def build_latent_image_batch_for_prompt_pair(
latent_list.append(neg_latent)
return torch.cat(latent_list, dim=0)
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
if trigger is None:
return prompt
output_prompt = prompt
default_replacements = ["[name]", "[trigger]"]
replace_with = trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
to_replace_list += default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with is in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
if num_instances > 1:
print(
f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt