mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Built base interfaces for a DTO to handle batch infomation transports for the dataloader
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
from typing import Optional, TYPE_CHECKING, List
|
||||
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import itertools
|
||||
|
||||
@@ -19,6 +18,27 @@ class ACTION_TYPES_SLIDER:
|
||||
ENHANCE_NEGATIVE = 1
|
||||
|
||||
|
||||
class PromptEmbeds:
|
||||
text_embeds: torch.Tensor
|
||||
pooled_embeds: Union[torch.Tensor, None]
|
||||
|
||||
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
|
||||
if isinstance(args, list) or isinstance(args, tuple):
|
||||
# xl
|
||||
self.text_embeds = args[0]
|
||||
self.pooled_embeds = args[1]
|
||||
else:
|
||||
# sdv1.x, sdv2.x
|
||||
self.text_embeds = args
|
||||
self.pooled_embeds = None
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
self.text_embeds = self.text_embeds.to(*args, **kwargs)
|
||||
if self.pooled_embeds is not None:
|
||||
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
|
||||
class EncodedPromptPair:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -465,3 +485,37 @@ def build_latent_image_batch_for_prompt_pair(
|
||||
latent_list.append(neg_latent)
|
||||
|
||||
return torch.cat(latent_list, dim=0)
|
||||
|
||||
|
||||
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
|
||||
if trigger is None:
|
||||
return prompt
|
||||
output_prompt = prompt
|
||||
default_replacements = ["[name]", "[trigger]"]
|
||||
|
||||
replace_with = trigger
|
||||
if to_replace_list is None:
|
||||
to_replace_list = default_replacements
|
||||
else:
|
||||
to_replace_list += default_replacements
|
||||
|
||||
# remove duplicates
|
||||
to_replace_list = list(set(to_replace_list))
|
||||
|
||||
# replace them all
|
||||
for to_replace in to_replace_list:
|
||||
# replace it
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = output_prompt.count(replace_with)
|
||||
|
||||
if num_instances == 0 and add_if_not_present:
|
||||
# add it to the beginning of the prompt
|
||||
output_prompt = replace_with + " " + output_prompt
|
||||
|
||||
if num_instances > 1:
|
||||
print(
|
||||
f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
|
||||
Reference in New Issue
Block a user