Tied in and tested TI script

Jaret Burkett
2023-08-23 13:26:28 -06:00
parent 2e6c55c720
commit d298240cec
6 changed files with 89 additions and 58 deletions

View File

@@ -52,41 +52,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
# very loosely based on this. very loosely
# ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
conditioned_prompts = []
# make sure the embedding is in the prompts
conditioned_prompts = [self.embedding.inject_embedding_to_prompt(
x,
expand_token=True,
add_if_not_present=True,
) for x in prompts]
for prompt in prompts:
# replace our name with the embedding
if self.embed_config.trigger in prompt:
# if the trigger is a part of the prompt, replace it with the token ids
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
if self.name in prompt:
# if the name is in the prompt, replace it with the trigger
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
if "[name]" in prompt:
# if [name] is in the prompt, replace it with the trigger
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
if self.embedding.get_embedding_string() not in prompt:
# add it to the beginning of the prompt
prompt = self.embedding.get_embedding_string() + " " + prompt
conditioned_prompts.append(prompt)
# # get embedding ids
# embedding_ids_list = [self.sd.tokenizer(
# text,
# padding="max_length",
# truncation=True,
# max_length=self.sd.tokenizer.model_max_length,
# return_tensors="pt",
# ).input_ids[0] for text in conditioned_prompts]
# hidden_states = []
# for embedding_ids, img in zip(embedding_ids_list, imgs):
# hidden_state = {
# "input_ids": embedding_ids,
# "pixel_values": img
# }
# hidden_states.append(hidden_state)
batch_size = imgs.shape[0]
dtype = get_torch_dtype(self.train_config.dtype)
imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -100,14 +73,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
self.train_config.max_denoising_steps, device=self.device_torch
)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
timesteps = timesteps.long()
# get noise
noise = self.sd.get_latent_noise(
pixel_height=imgs.shape[2],
pixel_width=imgs.shape[3],
batch_size=self.train_config.batch_size,
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
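
The changes above switch the trainer from a single shared timestep and a config-derived batch size to one random timestep per image, with the batch size read off the actual tensor (which also handles a short final batch). A minimal sketch of the resulting noising step, assuming a standard diffusers DDPMScheduler; the names here are illustrative, not the trainer's API:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.randn(4, 4, 64, 64)   # (batch, channels, h, w)
batch_size = latents.shape[0]         # read from the tensor, as above
noise = torch.randn_like(latents)
# one random timestep per sample, matching the change above
timesteps = torch.randint(0, 1000, (batch_size,)).long()
noisy_latents = scheduler.add_noise(latents, noise, timesteps)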

View File

@@ -106,20 +106,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
prompt = sample_config.prompts[i]
# add embedding if there is one
# note: diffusers will automatically expand the trigger to the number of added tokens
# i.e. test123 will become test123 test123_1 test123_2, etc. Do not add this yourself here
if self.embedding is not None:
# replace our name with the embedding
if self.embed_config.trigger in prompt:
# if the trigger is a part of the prompt, replace it with the token ids
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
if self.name in prompt:
# if the name is in the prompt, replace it with the trigger
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
if "[name]" in prompt:
# if [name] is in the prompt, replace it with the trigger
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
if self.embedding.get_embedding_string() not in prompt:
# add it to the beginning of the prompt
prompt = self.embedding.get_embedding_string() + " " + prompt
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
)
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
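
The note above relies on diffusers' textual-inversion loader: once a multi-vector embedding is loaded, the pipeline rewrites the base trigger into its numbered sub-tokens (via maybe_convert_prompt) before tokenization, so the sampler must pass the single trigger through unexpanded. A hedged illustration; the model id and file name are placeholders:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_textual_inversion("embedding.safetensors", token="test123")
# "a photo of test123" is rewritten internally to
# "a photo of test123 test123_1 test123_2 ..." before encoding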
@@ -208,6 +200,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
self.network.multiplier = prev_multiplier
elif self.embedding is not None:
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
if self.embed_config.save_format == "pt":
# replace extension
file_path = os.path.splitext(file_path)[0] + ".pt"
self.embedding.save(file_path)
else:
self.sd.save(
@@ -234,7 +232,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
def before_dataset_load(self):
pass
def hook_train_loop(self, batch=None):
def hook_train_loop(self, batch):
# return loss
return 0.0
@@ -365,6 +363,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# resume state from embedding
self.step_num = self.embedding.step
# set trainable params
params = self.embedding.get_trainable_params()

View File

@@ -54,9 +54,10 @@ class EmbeddingConfig:
def __init__(self, **kwargs):
self.trigger = kwargs.get('trigger', 'custom_embedding')
self.tokens = kwargs.get('tokens', 4)
self.init_words = kwargs.get('init_phrase', '*')
self.init_words = kwargs.get('init_words', '*')
self.save_format = kwargs.get('save_format', 'safetensors')
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
@@ -75,6 +76,7 @@ class TrainConfig:
self.optimizer_params = kwargs.get('optimizer_params', {})
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
class ModelConfig:
@@ -165,6 +167,7 @@ class DatasetConfig:
self.random_crop: bool = kwargs.get('random_crop', False)
self.resolution: int = kwargs.get('resolution', 512)
self.scale: float = kwargs.get('scale', 1.0)
self.buckets: bool = kwargs.get('buckets', False)
class GenerateImageConfig:
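
Since the init key is renamed from init_phrase to init_words here, existing configs need updating. A hypothetical construction showing the updated keys; the values are illustrative only:

embed_config = EmbeddingConfig(
    trigger="my_concept",
    tokens=4,
    init_words="a photo of a dog",  # renamed from init_phrase in this commit
    save_format="pt",               # default is "safetensors"
)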

View File

@@ -14,6 +14,13 @@ import albumentations as A
from toolkit.config_modules import DatasetConfig
from toolkit.dataloader_mixins import CaptionMixin
BUCKET_STEPS = 64
def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
# make sure resolution is divisible by 8
if resolution % 8 != 0:
resolution = resolution - (resolution % 8)
class ImageDataset(Dataset, CaptionMixin):
def __init__(self, config):
@@ -357,6 +364,7 @@ class AiToolkitDataset(Dataset, CaptionMixin):
def get_dataloader_from_datasets(dataset_options, batch_size=1):
# TODO do bucketing
if dataset_options is None or len(dataset_options) == 0:
return None
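
The new get_bucket_sizes_for_resolution is truncated in this view and dataloader bucketing is still a TODO. As a point of reference, one common aspect-ratio bucketing scheme keeps a roughly constant pixel budget and snaps both sides to BUCKET_STEPS; this sketch is an assumption for illustration, not this commit's implementation:

from typing import List, Tuple

BUCKET_STEPS = 64

def example_bucket_sizes(resolution: int) -> List[Tuple[int, int]]:
    max_pixels = resolution * resolution        # constant pixel budget
    sizes = []
    width = BUCKET_STEPS
    while width <= max_pixels // BUCKET_STEPS:
        height = (max_pixels // width) // BUCKET_STEPS * BUCKET_STEPS
        if height > 0 and 0.25 <= width / height <= 4.0:
            sizes.append((width, height))
        width += BUCKET_STEPS
    return sizes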

View File

@@ -25,7 +25,9 @@ class Embedding:
):
self.name = embed_config.trigger
self.sd = sd
self.trigger = embed_config.trigger
self.embed_config = embed_config
self.step = 0
# setup our embedding
# Add the placeholder token in tokenizer
placeholder_tokens = [self.embed_config.trigger]
@@ -64,10 +66,7 @@ class Embedding:
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# this doesnt seem to be used again
self.token_embeds = token_embeds
# replace "[name] with this. This triggers it in the text encoder
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
# returns the string to have in the prompt to trigger the embedding
@@ -86,7 +85,7 @@ class Embedding:
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
# stack the tokens along batch axis adding that axis
new_vector = torch.stack(
[token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids],
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
dim=0
)
return new_vector
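
The change above drops a redundant unsqueeze(0): torch.stack already inserts the new dimension, so stacking pre-unsqueezed rows produced an extra singleton axis. A quick illustration:

import torch

rows = [torch.randn(768) for _ in range(4)]
print(torch.stack([r.unsqueeze(0) for r in rows], dim=0).shape)  # torch.Size([4, 1, 768])
print(torch.stack(rows, dim=0).shape)                            # torch.Size([4, 768])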
@@ -100,6 +99,39 @@ class Embedding:
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
x = 1
# diffusers automatically expands the token, meaning test123 becomes test123 test123_1 test123_2, etc.
# however, during training we don't use that pipeline, so we have to do the expansion ourselves
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
output_prompt = prompt
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
replace_with = self.embedding_tokens if expand_token else self.trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
to_replace_list += default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# count how many times replace_with appears in the result
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
if num_instances > 1:
print(
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
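
For reference, hedged illustrations of the two call sites this commit wires up; the embedding object and prompt text are hypothetical:

# Training path: expand the trigger into all placeholder sub-tokens ourselves,
# since the training loop bypasses the diffusers pipeline.
train_prompt = embedding.inject_embedding_to_prompt(
    "a photo of [name]", expand_token=True, add_if_not_present=True
)
# Sampling path: leave the single trigger in place; diffusers expands it later.
sample_prompt = embedding.inject_embedding_to_prompt("a photo of [name]")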
def save(self, filename):
# todo check to see how to get the vector out of the embedding
@@ -107,7 +139,7 @@ class Embedding:
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": 0,
"step": self.step,
# todo get these
"sd_checkpoint": None,
"sd_checkpoint_name": None,
@@ -182,4 +214,7 @@ class Embedding:
raise Exception(
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
if 'step' in data:
self.step = int(data['step'])
self.vec = emb.detach().to(device, dtype=torch.float32)
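
The .pt layout written above follows the common string_to_param textual-inversion format; a hedged sketch of reading it back (the file name is a placeholder):

import torch

data = torch.load("embedding.pt", map_location="cpu")
vec = data["string_to_param"]["*"]   # shape: (num_tokens, embed_dim)
step = int(data.get("step", 0))      # resume info persisted by this commit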

View File

@@ -435,7 +435,7 @@ class StableDiffusion:
text_embeddings = train_tools.concat_prompt_embeddings(
unconditional_embeddings, # negative embedding
conditional_embeddings, # positive embedding
latents.shape[0], # batch size
1, # batch size
)
elif text_embeddings is None and conditional_embeddings is not None:
# not doing cfg
@@ -506,6 +506,17 @@ class StableDiffusion:
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
# check if we need to concat timesteps
if isinstance(timestep, torch.Tensor):
ts_bs = timestep.shape[0]
if ts_bs != latent_model_input.shape[0]:
if ts_bs == 1:
timestep = torch.cat([timestep] * latent_model_input.shape[0])
elif ts_bs * 2 == latent_model_input.shape[0]:
timestep = torch.cat([timestep] * 2)
else:
raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
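
The ts_bs * 2 branch above exists because classifier-free guidance duplicates the latents for the unconditional and conditional passes; a minimal illustration:

import torch

latents = torch.randn(2, 4, 64, 64)
timesteps = torch.tensor([500, 700])
latent_model_input = torch.cat([latents] * 2)   # CFG: batch doubles to 4
timesteps = torch.cat([timesteps] * 2)          # broadcast to match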