diff --git a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py
index 31da8402..9eb6e364 100644
--- a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py
+++ b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py
@@ -52,41 +52,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
         # very loosely based on this. very loosely
         # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
-        conditioned_prompts = []
+        # make sure the embedding is in the prompts
+        conditioned_prompts = [self.embedding.inject_embedding_to_prompt(
+            x,
+            expand_token=True,
+            add_if_not_present=True,
+        ) for x in prompts]
 
-        for prompt in prompts:
-            # replace our name with the embedding
-            if self.embed_config.trigger in prompt:
-                # if the trigger is a part of the prompt, replace it with the token ids
-                prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
-            if self.name in prompt:
-                # if the name is in the prompt, replace it with the trigger
-                prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
-            if "[name]" in prompt:
-                # in [name] in prompt, replace it with the trigger
-                prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
-            if self.embedding.get_embedding_string() not in prompt:
-                # add it to the beginning of the prompt
-                prompt = self.embedding.get_embedding_string() + " " + prompt
-
-            conditioned_prompts.append(prompt)
-
-        # # get embedding ids
-        # embedding_ids_list = [self.sd.tokenizer(
-        #     text,
-        #     padding="max_length",
-        #     truncation=True,
-        #     max_length=self.sd.tokenizer.model_max_length,
-        #     return_tensors="pt",
-        # ).input_ids[0] for text in conditioned_prompts]
-
-        # hidden_states = []
-        # for embedding_ids, img in zip(embedding_ids_list, imgs):
-        #     hidden_state = {
-        #         "input_ids": embedding_ids,
-        #         "pixel_values": img
-        #     }
-        #     hidden_states.append(hidden_state)
+        batch_size = imgs.shape[0]
 
         dtype = get_torch_dtype(self.train_config.dtype)
         imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -100,14 +73,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
             self.train_config.max_denoising_steps,
             device=self.device_torch
         )
-        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
+        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
         timesteps = timesteps.long()
 
         # get noise
         noise = self.sd.get_latent_noise(
             pixel_height=imgs.shape[2],
             pixel_width=imgs.shape[3],
-            batch_size=self.train_config.batch_size,
+            batch_size=batch_size,
             noise_offset=self.train_config.noise_offset
         ).to(self.device_torch, dtype=dtype)
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index ae269e52..838bd8dc 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -106,20 +106,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 prompt = sample_config.prompts[i]
                 # add embedding if there is one
+                # note: diffusers will automatically expand the trigger to the number of added tokens
+                # ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
                 if self.embedding is not None:
-                    # replace our name with the embedding
-                    if self.embed_config.trigger in prompt:
-                        # if the trigger is a part of the prompt, replace it with the token ids
-                        prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
-                    if self.name in prompt:
-                        # if the name is in the prompt, replace it with the trigger
-                        prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
-                    if "[name]" in prompt:
-                        # in [name] in prompt, replace it with the trigger
-                        prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
-                    if self.embedding.get_embedding_string() not in prompt:
-                        # add it to the beginning of the prompt
-                        prompt = self.embedding.get_embedding_string() + " " + prompt
+                    prompt = self.embedding.inject_embedding_to_prompt(
+                        prompt,
+                    )
 
                 gen_img_config_list.append(GenerateImageConfig(
                     prompt=prompt,  # it will autoparse the prompt
@@ -208,6 +200,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
             )
             self.network.multiplier = prev_multiplier
         elif self.embedding is not None:
+            # set current step
+            self.embedding.step = self.step_num
+            # change filename to pt if that is set
+            if self.embed_config.save_format == "pt":
+                # replace extension
+                file_path = os.path.splitext(file_path)[0] + ".pt"
             self.embedding.save(file_path)
         else:
             self.sd.save(
@@ -234,7 +232,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def before_dataset_load(self):
         pass
 
-    def hook_train_loop(self, batch=None):
+    def hook_train_loop(self, batch):
         # return loss
         return 0.0
 
@@ -365,6 +363,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if latest_save_path is not None:
                     self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
 
+            # resume state from embedding
+            self.step_num = self.embedding.step
+
             # set trainable params
             params = self.embedding.get_trainable_params()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index d79eeb53..46e22fda 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -54,9 +54,10 @@ class EmbeddingConfig:
     def __init__(self, **kwargs):
         self.trigger = kwargs.get('trigger', 'custom_embedding')
         self.tokens = kwargs.get('tokens', 4)
-        self.init_words = kwargs.get('init_phrase', '*')
+        self.init_words = kwargs.get('init_words', '*')
         self.save_format = kwargs.get('save_format', 'safetensors')
 
+
 class TrainConfig:
     def __init__(self, **kwargs):
         self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
@@ -75,6 +76,7 @@ class TrainConfig:
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
+        self.weight_jitter = kwargs.get('weight_jitter', 0.0)
 
 
 class ModelConfig:
@@ -165,6 +167,7 @@ class DatasetConfig:
         self.random_crop: bool = kwargs.get('random_crop', False)
         self.resolution: int = kwargs.get('resolution', 512)
         self.scale: float = kwargs.get('scale', 1.0)
+        self.buckets: bool = kwargs.get('buckets', False)
 
 
 class GenerateImageConfig:
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index e399d1be..2ccf6890 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -14,6 +14,13 @@ import albumentations as A
 from toolkit.config_modules import DatasetConfig
 from toolkit.dataloader_mixins import CaptionMixin
 
+BUCKET_STEPS = 64
+
+def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
+    # make sure resolution is divisible by 8
+    if resolution % 8 != 0:
+        resolution = resolution - (resolution % 8)
+
 
 class ImageDataset(Dataset, CaptionMixin):
     def __init__(self, config):
@@ -357,6 +364,7 @@ class AiToolkitDataset(Dataset, CaptionMixin):
 
 
 def get_dataloader_from_datasets(dataset_options, batch_size=1):
+    # TODO do bucketing
     if dataset_options is None or len(dataset_options) == 0:
         return None
diff --git a/toolkit/embedding.py b/toolkit/embedding.py
index 40e01250..3eb4483a 100644
--- a/toolkit/embedding.py
+++ b/toolkit/embedding.py
@@ -25,7 +25,9 @@ class Embedding:
     ):
         self.name = embed_config.trigger
         self.sd = sd
+        self.trigger = embed_config.trigger
         self.embed_config = embed_config
+        self.step = 0
         # setup our embedding
         # Add the placeholder token in tokenizer
         placeholder_tokens = [self.embed_config.trigger]
@@ -64,10 +66,7 @@ class Embedding:
         for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
             token_embeds[token_id] = token_embeds[initializer_token_id].clone()
 
-        # this doesnt seem to be used again
-        self.token_embeds = token_embeds
-
-        # replace "[name] with this. This triggers it in the text encoder
+        # replace "[name] with this. on training. This is automatically generated in pipeline on inference
         self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
 
     # returns the string to have in the prompt to trigger the embedding
@@ -86,7 +85,7 @@ class Embedding:
         token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
         # stack the tokens along batch axis adding that axis
         new_vector = torch.stack(
-            [token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids],
+            [token_embeds[token_id] for token_id in self.placeholder_token_ids],
             dim=0
         )
         return new_vector
@@ -100,6 +99,39 @@ class Embedding:
                 token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
         x = 1
 
+    # diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
+    # however, on training we don't use that pipeline, so we have to do it ourselves
+    def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
+        output_prompt = prompt
+        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
+
+        replace_with = self.embedding_tokens if expand_token else self.trigger
+        if to_replace_list is None:
+            to_replace_list = default_replacements
+        else:
+            to_replace_list += default_replacements
+
+        # remove duplicates
+        to_replace_list = list(set(to_replace_list))
+
+        # replace them all
+        for to_replace in to_replace_list:
+            # replace it
+            output_prompt = output_prompt.replace(to_replace, replace_with)
+
+        # see how many times replace_with is in the prompt
+        num_instances = prompt.count(replace_with)
+
+        if num_instances == 0 and add_if_not_present:
+            # add it to the beginning of the prompt
+            output_prompt = replace_with + " " + output_prompt
+
+        if num_instances > 1:
+            print(
+                f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+
+        return output_prompt
+
     def save(self, filename):
         # todo check to see how to get the vector out of the embedding
@@ -107,7 +139,7 @@
             "string_to_token": {"*": 265},
             "string_to_param": {"*": self.vec},
             "name": self.name,
-            "step": 0,
+            "step": self.step,
             # todo get these
             "sd_checkpoint": None,
             "sd_checkpoint_name": None,
@@ -182,4 +214,7 @@
             raise Exception(
                 f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
 
+        if 'step' in data:
+            self.step = int(data['step'])
+
         self.vec = emb.detach().to(device, dtype=torch.float32)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 4e7c1141..e4667ee6 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -435,7 +435,7 @@ class StableDiffusion:
             text_embeddings = train_tools.concat_prompt_embeddings(
                 unconditional_embeddings,  # negative embedding
                 conditional_embeddings,  # positive embedding
-                latents.shape[0],  # batch size
+                1,  # batch size
             )
         elif text_embeddings is None and conditional_embeddings is not None:
             # not doing cfg
@@ -506,6 +517,17 @@
         latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
 
+        # check if we need to concat timesteps
+        if isinstance(timestep, torch.Tensor):
+            ts_bs = timestep.shape[0]
+            if ts_bs != latent_model_input.shape[0]:
+                if ts_bs == 1:
+                    timestep = torch.cat([timestep] * latent_model_input.shape[0])
+                elif ts_bs * 2 == latent_model_input.shape[0]:
+                    timestep = torch.cat([timestep] * 2)
+                else:
+                    raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
+
         # predict the noise residual
         noise_pred = self.unet(
             latent_model_input,
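
For reference, a minimal usage sketch of the new Embedding.inject_embedding_to_prompt helper, mirroring the two call sites changed above. The trigger "test123", the prompt strings, and the already-constructed `embedding` object are illustrative assumptions, not part of the patch:

# usage sketch only -- assumes an Embedding instance named `embedding` that was
# built with trigger "test123"; the prompts are illustrative
prompts = ["a photo of test123", "a photo of a cat"]

# sampling path (BaseSDTrainProcess): keep the plain trigger word; the diffusers
# pipeline expands it to the extra tokens at inference time
sample_prompt = embedding.inject_embedding_to_prompt("a photo of test123")
# trigger already present once, so the prompt comes back unchanged

missing_trigger = embedding.inject_embedding_to_prompt("a photo of a cat")
# trigger absent, so it is prepended: "test123 a photo of a cat"

# training path (TextualInversionTrainer): expand the trigger to the full
# "test123 test123_1 ..." token string, since the raw text-encoder call used
# in training does not go through the pipeline's automatic expansion
conditioned_prompts = [
    embedding.inject_embedding_to_prompt(
        x,
        expand_token=True,
        add_if_not_present=True,
    ) for x in prompts
]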
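
The bucketing work in toolkit/data_loader.py is only stubbed out here: get_bucket_sizes_for_resolution has no return value yet and get_dataloader_from_datasets just gains a "# TODO do bucketing" marker. Purely as an illustration of what a completed helper matching the declared List[int] signature might look like (an assumption, not what the patch or the repository implements), one option is to enumerate side lengths in BUCKET_STEPS increments around the base resolution:

from typing import List

BUCKET_STEPS = 64

def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
    # sketch only -- the patch leaves this function unfinished
    # make sure resolution is divisible by 8, as in the stub
    if resolution % 8 != 0:
        resolution = resolution - (resolution % 8)
    # assumption: buckets range from half to double the base resolution
    smallest = max(BUCKET_STEPS, (resolution // 2) // BUCKET_STEPS * BUCKET_STEPS)
    largest = (resolution * 2) // BUCKET_STEPS * BUCKET_STEPS
    return list(range(smallest, largest + 1, BUCKET_STEPS))

# get_bucket_sizes_for_resolution(512) -> [256, 320, 384, ..., 960, 1024]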
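
The last hunk in toolkit/stable_diffusion_model.py makes predict_noise tolerant of timestep tensors whose batch size does not match the (possibly CFG-doubled) latent batch. The same rule pulled out into a small standalone function with toy tensors so it can be run in isolation; the function name is ours, not part of the patch:

import torch

def match_timestep_batch(timestep: torch.Tensor, latent_batch: int) -> torch.Tensor:
    # broadcast a single timestep to the latent batch, or double it when the
    # latents were doubled for classifier-free guidance; anything else is an error
    if timestep.shape[0] == latent_batch:
        return timestep
    if timestep.shape[0] == 1:
        return torch.cat([timestep] * latent_batch)
    if timestep.shape[0] * 2 == latent_batch:
        return torch.cat([timestep] * 2)
    raise ValueError("timestep batch must equal the latent batch, be 1, or be half of it")

print(match_timestep_batch(torch.tensor([10]), 4))      # tensor([10, 10, 10, 10])
print(match_timestep_batch(torch.tensor([10, 20]), 4))  # tensor([10, 20, 10, 20])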