Built base interfaces for a DTO to handle batch infomation transports for the dataloader

This commit is contained in:
Jaret Burkett
2023-08-28 12:43:31 -06:00
parent 71da78c8af
commit e866c75638
7 changed files with 186 additions and 76 deletions

View File

@@ -16,6 +16,7 @@ from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
@@ -63,25 +64,7 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
class PromptEmbeds:
text_embeds: torch.Tensor
pooled_embeds: Union[torch.Tensor, None]
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
if isinstance(args, list) or isinstance(args, tuple):
# xl
self.text_embeds = args[0]
self.pooled_embeds = args[1]
else:
# sdv1.x, sdv2.x
self.text_embeds = args
self.pooled_embeds = None
def to(self, *args, **kwargs):
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
return self
# if is type checking
@@ -708,38 +691,12 @@ class StableDiffusion:
raise ValueError(f"Unknown weight name: {name}")
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
if trigger is None:
return prompt
output_prompt = prompt
default_replacements = ["[name]", "[trigger]"]
num_times_trigger_exists = prompt.count(trigger)
replace_with = trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
to_replace_list += default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with is in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
if num_instances > 1:
print(
f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
return inject_trigger_into_prompt(
prompt,
trigger=trigger,
to_replace_list=to_replace_list,
add_if_not_present=add_if_not_present,
)
def state_dict(self, vae=True, text_encoder=True, unet=True):
state_dict = OrderedDict()