mirror of https://github.com/ostris/ai-toolkit.git
Tied in and tested TI script
@@ -52,41 +52,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
         # very loosely based on this. very loosely
         # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
 
-        conditioned_prompts = []
-
-        for prompt in prompts:
-            # replace our name with the embedding
-            if self.embed_config.trigger in prompt:
-                # if the trigger is a part of the prompt, replace it with the token ids
-                prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
-            if self.name in prompt:
-                # if the name is in the prompt, replace it with the trigger
-                prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
-            if "[name]" in prompt:
-                # in [name] in prompt, replace it with the trigger
-                prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
-            if self.embedding.get_embedding_string() not in prompt:
-                # add it to the beginning of the prompt
-                prompt = self.embedding.get_embedding_string() + " " + prompt
-
-            conditioned_prompts.append(prompt)
-
-        # # get embedding ids
-        # embedding_ids_list = [self.sd.tokenizer(
-        #     text,
-        #     padding="max_length",
-        #     truncation=True,
-        #     max_length=self.sd.tokenizer.model_max_length,
-        #     return_tensors="pt",
-        # ).input_ids[0] for text in conditioned_prompts]
-
-        # hidden_states = []
-        # for embedding_ids, img in zip(embedding_ids_list, imgs):
-        #     hidden_state = {
-        #         "input_ids": embedding_ids,
-        #         "pixel_values": img
-        #     }
-        #     hidden_states.append(hidden_state)
+        # make sure the embedding is in the prompts
+        conditioned_prompts = [self.embedding.inject_embedding_to_prompt(
+            x,
+            expand_token=True,
+            add_if_not_present=True,
+        ) for x in prompts]
+
+        batch_size = imgs.shape[0]
 
         dtype = get_torch_dtype(self.train_config.dtype)
         imgs = imgs.to(self.device_torch, dtype=dtype)
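For context on expand_token=True: a multi-vector embedding occupies several tokenizer entries, and the training path has to spell them all out in the prompt itself; the pipeline only does this automatically at inference, as noted later in this diff. A minimal sketch of that expansion with an illustrative trigger name, not code from this commit:

# Sketch: expand a 4-token trigger into its suffixed token names, following
# the test123 / test123_1 / test123_2 convention described in this commit.
def expand_trigger(trigger: str, num_tokens: int) -> str:
    return " ".join([trigger] + [f"{trigger}_{i}" for i in range(1, num_tokens)])

prompts = ["a photo of test123 on a beach", "a painting of a cat"]
expanded = expand_trigger("test123", 4)  # "test123 test123_1 test123_2 test123_3"
conditioned = [p.replace("test123", expanded) if "test123" in p else f"{expanded} {p}"
               for p in prompts]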
@@ -100,14 +73,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
             self.train_config.max_denoising_steps, device=self.device_torch
         )
 
-        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
+        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
         timesteps = timesteps.long()
 
         # get noise
         noise = self.sd.get_latent_noise(
             pixel_height=imgs.shape[2],
             pixel_width=imgs.shape[3],
-            batch_size=self.train_config.batch_size,
+            batch_size=batch_size,
             noise_offset=self.train_config.noise_offset
         ).to(self.device_torch, dtype=dtype)
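Per-sample timesteps (shape (batch_size,) instead of (1,)) are the standard diffusion-training recipe, so each image in the batch gets an independent noise level. A self-contained sketch of the same idea with diffusers' DDPMScheduler:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
batch_size = 4
latents = torch.randn(batch_size, 4, 64, 64)
noise = torch.randn_like(latents)

# one independent timestep per sample, as in the change above
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,)).long()
noisy_latents = scheduler.add_noise(latents, noise, timesteps)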
@@ -106,20 +106,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
             prompt = sample_config.prompts[i]
 
             # add embedding if there is one
+            # note: diffusers will automatically expand the trigger to the number of added tokens
+            # ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
             if self.embedding is not None:
-                # replace our name with the embedding
-                if self.embed_config.trigger in prompt:
-                    # if the trigger is a part of the prompt, replace it with the token ids
-                    prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
-                if self.name in prompt:
-                    # if the name is in the prompt, replace it with the trigger
-                    prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
-                if "[name]" in prompt:
-                    # in [name] in prompt, replace it with the trigger
-                    prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
-                if self.embedding.get_embedding_string() not in prompt:
-                    # add it to the beginning of the prompt
-                    prompt = self.embedding.get_embedding_string() + " " + prompt
+                prompt = self.embedding.inject_embedding_to_prompt(
+                    prompt,
+                )
 
             gen_img_config_list.append(GenerateImageConfig(
                 prompt=prompt,  # it will autoparse the prompt
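The "do not add this yourself" note refers to diffusers' multi-vector expansion at inference time. In diffusers this lives on TextualInversionLoaderMixin as maybe_convert_prompt; a rough sketch, with the model id and token purely illustrative, assuming the embedding was registered first:

from diffusers import StableDiffusionPipeline

# assumes a multi-vector embedding was loaded under "test123", e.g. via
# pipe.load_textual_inversion(...); the model id is just an example
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# rewrites "test123" to "test123 test123_1 test123_2 ..." automatically,
# which is why the sampling path above passes the bare trigger through
expanded = pipe.maybe_convert_prompt("a photo of test123", pipe.tokenizer)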
@@ -208,6 +200,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 )
                 self.network.multiplier = prev_multiplier
+            elif self.embedding is not None:
+                # set current step
+                self.embedding.step = self.step_num
+                # change filename to pt if that is set
+                if self.embed_config.save_format == "pt":
+                    # replace extension
+                    file_path = os.path.splitext(file_path)[0] + ".pt"
+                self.embedding.save(file_path)
             else:
                 self.sd.save(
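Embedding.save itself appears later in this diff; the branch above only rewrites the extension. A hedged sketch of what the two save_format values typically mean on disk (the "emb_params" safetensors key is an assumption, not shown in this commit; the .pt dict mirrors the payload in the Embedding.save hunk below):

import os
import torch
from safetensors.torch import save_file

vec = torch.randn(4, 768)  # illustrative 4-token, SD1.x-width vector
file_path = "/tmp/my_embedding.safetensors"
save_format = "pt"

if save_format == "pt":
    # same extension swap as above, then the A1111-style dict
    file_path = os.path.splitext(file_path)[0] + ".pt"
    torch.save({"string_to_param": {"*": vec}, "name": "my_embedding", "step": 500}, file_path)
else:
    # safetensors stores a flat dict of tensors only
    save_file({"emb_params": vec}, file_path)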
@@ -234,7 +232,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def before_dataset_load(self):
         pass
 
-    def hook_train_loop(self, batch=None):
+    def hook_train_loop(self, batch):
         # return loss
         return 0.0
@@ -365,6 +363,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if latest_save_path is not None:
                 self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
 
+                # resume state from embedding
+                self.step_num = self.embedding.step
+
             # set trainable params
             params = self.embedding.get_trainable_params()
@@ -54,9 +54,10 @@ class EmbeddingConfig:
     def __init__(self, **kwargs):
         self.trigger = kwargs.get('trigger', 'custom_embedding')
         self.tokens = kwargs.get('tokens', 4)
-        self.init_words = kwargs.get('init_phrase', '*')
+        self.init_words = kwargs.get('init_words', '*')
+        self.save_format = kwargs.get('save_format', 'safetensors')
 
 
 class TrainConfig:
     def __init__(self, **kwargs):
         self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
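Note the breaking rename: configs that set init_phrase must now say init_words. A small sketch of constructing the config, assuming EmbeddingConfig is importable from toolkit.config_modules alongside the DatasetConfig imported elsewhere in this diff; values illustrative:

from toolkit.config_modules import EmbeddingConfig  # import path assumed

embed_config = EmbeddingConfig(**{
    "trigger": "my_concept",
    "tokens": 4,
    "init_words": "a photo of a dog",  # was 'init_phrase' before this commit
    "save_format": "pt",               # new key; defaults to 'safetensors'
})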
@@ -75,6 +76,7 @@ class TrainConfig:
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
+        self.weight_jitter = kwargs.get('weight_jitter', 0.0)
 
 
 class ModelConfig:
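weight_jitter gains a default here, but its consumer is not part of this diff. Purely as a hypothetical reading of the name, knobs like this usually add small Gaussian noise to the trainable weights each step:

import torch

def apply_weight_jitter(vec: torch.Tensor, weight_jitter: float) -> torch.Tensor:
    # hypothetical helper: this commit does not show how weight_jitter is used
    if weight_jitter <= 0.0:
        return vec
    return vec + torch.randn_like(vec) * weight_jitter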
@@ -165,6 +167,7 @@ class DatasetConfig:
         self.random_crop: bool = kwargs.get('random_crop', False)
         self.resolution: int = kwargs.get('resolution', 512)
         self.scale: float = kwargs.get('scale', 1.0)
+        self.buckets: bool = kwargs.get('buckets', False)
 
 
 class GenerateImageConfig:
@@ -14,6 +14,13 @@ import albumentations as A
 from toolkit.config_modules import DatasetConfig
 from toolkit.dataloader_mixins import CaptionMixin
 
+BUCKET_STEPS = 64
+
+
+def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
+    # make sure resolution is divisible by 8
+    if resolution % 8 != 0:
+        resolution = resolution - (resolution % 8)
 
 
 class ImageDataset(Dataset, CaptionMixin):
     def __init__(self, config):
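Only the head of get_bucket_sizes_for_resolution survives in this view, and the declared List[int] return type does not match the (width, height) pairs bucketing usually needs. The following is an assumed completion in the common BUCKET_STEPS style, not the commit's actual body:

from typing import List, Tuple

BUCKET_STEPS = 64

def get_bucket_sizes_for_resolution(resolution: int) -> List[Tuple[int, int]]:
    # make sure resolution is divisible by 8 (as in the shown head)
    if resolution % 8 != 0:
        resolution = resolution - (resolution % 8)
    max_pixels = resolution * resolution
    sizes = []
    for width in range(BUCKET_STEPS, resolution * 2 + 1, BUCKET_STEPS):
        # tallest BUCKET_STEPS-aligned height keeping area <= resolution^2
        height = (max_pixels // width) // BUCKET_STEPS * BUCKET_STEPS
        if height >= BUCKET_STEPS and max(width, height) <= 4 * min(width, height):
            sizes.append((width, height))
    return sizes

print(get_bucket_sizes_for_resolution(512)[:3])  # [(256, 1024), (320, 768), (384, 640)]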
@@ -357,6 +364,7 @@ class AiToolkitDataset(Dataset, CaptionMixin):
 
 
 def get_dataloader_from_datasets(dataset_options, batch_size=1):
+    # TODO do bucketing
     if dataset_options is None or len(dataset_options) == 0:
         return None
@@ -25,7 +25,9 @@ class Embedding:
     ):
         self.name = embed_config.trigger
         self.sd = sd
+        self.trigger = embed_config.trigger
         self.embed_config = embed_config
+        self.step = 0
         # setup our embedding
         # Add the placeholder token in tokenizer
         placeholder_tokens = [self.embed_config.trigger]
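The placeholder-token setup that follows matches the referenced huggingface textual_inversion example. Condensed to its core, the recipe is (model ids illustrative):

from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

trigger, num_tokens = "test123", 4
placeholder_tokens = [trigger] + [f"{trigger}_{i}" for i in range(1, num_tokens)]
num_added = tokenizer.add_tokens(placeholder_tokens)
assert num_added == num_tokens, "trigger collides with an existing token"

text_encoder.resize_token_embeddings(len(tokenizer))
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

# initialize each new embedding row from the init word's embedding
init_token_id = tokenizer.encode("dog", add_special_tokens=False)[0]
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_id in placeholder_token_ids:
    token_embeds[token_id] = token_embeds[init_token_id].clone()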
@@ -64,10 +66,7 @@ class Embedding:
         for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
             token_embeds[token_id] = token_embeds[initializer_token_id].clone()
 
-        # this doesn't seem to be used again
-        self.token_embeds = token_embeds
-
-        # replace "[name]" with this. This triggers it in the text encoder
+        # replace "[name]" with this on training. At inference, the pipeline generates it automatically
         self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
 
     # returns the string to have in the prompt to trigger the embedding
@@ -86,7 +85,7 @@ class Embedding:
         token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
         # stack the tokens along batch axis adding that axis
         new_vector = torch.stack(
-            [token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids],
+            [token_embeds[token_id] for token_id in self.placeholder_token_ids],
             dim=0
         )
         return new_vector
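The dropped unsqueeze(0) was producing a spurious middle axis, since torch.stack already inserts the new dimension. A quick shape check:

import torch

token_embeds = torch.randn(10, 768)
ids = [3, 4, 5]

with_unsqueeze = torch.stack([token_embeds[i].unsqueeze(0) for i in ids], dim=0)
without = torch.stack([token_embeds[i] for i in ids], dim=0)
print(with_unsqueeze.shape)  # torch.Size([3, 1, 768])
print(without.shape)         # torch.Size([3, 768])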
@@ -100,6 +99,39 @@ class Embedding:
                 token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
         x = 1
 
+    # diffusers automatically expands the token, meaning test123 becomes test123 test123_1 test123_2 etc
+    # however, on training we don't use that pipeline, so we have to do it ourselves
+    def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
+        output_prompt = prompt
+        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
+
+        replace_with = self.embedding_tokens if expand_token else self.trigger
+        if to_replace_list is None:
+            to_replace_list = default_replacements
+        else:
+            to_replace_list += default_replacements
+
+        # remove duplicates
+        to_replace_list = list(set(to_replace_list))
+
+        # replace them all
+        for to_replace in to_replace_list:
+            output_prompt = output_prompt.replace(to_replace, replace_with)
+
+        # count occurrences in the output, not the input, so a trigger that was
+        # just replaced is not prepended a second time
+        num_instances = output_prompt.count(replace_with)
+
+        if num_instances == 0 and add_if_not_present:
+            # add it to the beginning of the prompt
+            output_prompt = replace_with + " " + output_prompt
+
+        if num_instances > 1:
+            print(
+                f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+
+        return output_prompt
+
     def save(self, filename):
         # todo check to see how to get the vector out of the embedding
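A usage sketch for the new method, with a duck-typed stub in place of a real Embedding (the toolkit.embedding import path is an assumption, and trigger names are illustrative). Training expands the token; sampling keeps the bare trigger:

from toolkit.embedding import Embedding  # import path assumed

class Stub:
    name = "test123"
    trigger = "test123"
    embedding_tokens = "test123 test123_1 test123_2"

stub = Stub()
# training path: spell out every placeholder token
print(Embedding.inject_embedding_to_prompt(stub, "a photo of test123", expand_token=True))
# -> "a photo of test123 test123_1 test123_2"
# sampling path: bare trigger; diffusers expands it at inference
print(Embedding.inject_embedding_to_prompt(stub, "portrait of [name]"))
# -> "portrait of test123"
# missing trigger gets prepended by default
print(Embedding.inject_embedding_to_prompt(stub, "a mountain lake"))
# -> "test123 a mountain lake"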
@@ -107,7 +139,7 @@ class Embedding:
             "string_to_token": {"*": 265},
             "string_to_param": {"*": self.vec},
             "name": self.name,
-            "step": 0,
+            "step": self.step,
             # todo get these
             "sd_checkpoint": None,
             "sd_checkpoint_name": None,
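With step now persisted, the A1111-style .pt payload round-trips. Reading one back looks roughly like this (path illustrative):

import torch

data = torch.load("my_embedding.pt", map_location="cpu")
vec = data["string_to_param"]["*"]  # (num_tokens, hidden_dim) tensor
step = int(data.get("step", 0))     # restored by the load hunk below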
@@ -182,4 +214,7 @@ class Embedding:
             raise Exception(
                 f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
 
+        if 'step' in data:
+            self.step = int(data['step'])
+
         self.vec = emb.detach().to(device, dtype=torch.float32)
@@ -435,7 +435,7 @@ class StableDiffusion:
             text_embeddings = train_tools.concat_prompt_embeddings(
                 unconditional_embeddings,  # negative embedding
                 conditional_embeddings,  # positive embedding
-                latents.shape[0],  # batch size
+                1,  # batch size
             )
         elif text_embeddings is None and conditional_embeddings is not None:
             # not doing cfg
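concat_prompt_embeddings is internal to this repo, so the repeat-count change from latents.shape[0] to 1 presumably reflects embeddings that are already batched per sample. For reference, the generic classifier-free-guidance layout it has to produce:

import torch

# generic CFG batching, not the repo's helper itself
uncond = torch.randn(2, 77, 768)  # negative embeddings, one per sample
cond = torch.randn(2, 77, 768)    # positive embeddings, one per sample
text_embeddings = torch.cat([uncond, cond], dim=0)  # (4, 77, 768)

latents = torch.randn(2, 4, 64, 64)
latent_model_input = torch.cat([latents] * 2, dim=0)  # (4, ...) matches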
@@ -506,6 +506,17 @@ class StableDiffusion:
 
             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
 
+            # check if we need to concat timesteps
+            if isinstance(timestep, torch.Tensor):
+                ts_bs = timestep.shape[0]
+                if ts_bs != latent_model_input.shape[0]:
+                    if ts_bs == 1:
+                        timestep = torch.cat([timestep] * latent_model_input.shape[0])
+                    elif ts_bs * 2 == latent_model_input.shape[0]:
+                        timestep = torch.cat([timestep] * 2)
+                    else:
+                        raise ValueError(f"Batch size of timesteps {timestep.shape[0]} must equal or be half the batch size of latents {latent_model_input.shape[0]}")
+
             # predict the noise residual
             noise_pred = self.unet(
                 latent_model_input,
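The guard covers the three shapes a tensor timestep can arrive in now that training samples (batch_size,) timesteps while CFG doubles the latent batch. The same logic as a standalone check:

import torch

def match_timesteps(timestep: torch.Tensor, latent_bs: int) -> torch.Tensor:
    # mirrors the guard added in the hunk above
    ts_bs = timestep.shape[0]
    if ts_bs == latent_bs:
        return timestep
    if ts_bs == 1:
        return torch.cat([timestep] * latent_bs)
    if ts_bs * 2 == latent_bs:
        return torch.cat([timestep] * 2)
    raise ValueError("timesteps must equal or be half the latent batch size")

print(match_timesteps(torch.tensor([10]), 4).shape)      # torch.Size([4])
print(match_timesteps(torch.tensor([10, 20]), 4).shape)  # torch.Size([4])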