diff --git a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py new file mode 100644 index 00000000..31da8402 --- /dev/null +++ b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py @@ -0,0 +1,179 @@ +import copy +import random +from collections import OrderedDict +import os +from contextlib import nullcontext +from typing import Optional, Union, List +from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig +from toolkit.data_loader import PairedImageDataset, ImageDataset +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset +import gc +from toolkit import train_tools +import torch +from jobs.process import BaseSDTrainProcess +import random +from toolkit.basic import value_map + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class TextualInversionTrainer(BaseSDTrainProcess): + sd: StableDiffusion + data_loader: DataLoader = None + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + pass + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + + # keep original embeddings as reference + self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone() + # set text encoder to train. Not sure if this is necessary but diffusers example did it + self.sd.text_encoder.train() + pass + + def hook_train_loop(self, batch): + with torch.no_grad(): + imgs, prompts = batch + + # very loosely based on this. 
+ # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py + + conditioned_prompts = [] + + for prompt in prompts: + # replace our name with the embedding + if self.embed_config.trigger in prompt: + # if the trigger is part of the prompt, replace it with the embedding token string + prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string()) + if self.name in prompt: + # if the job name is in the prompt, replace it with the embedding token string + prompt = prompt.replace(self.name, self.embedding.get_embedding_string()) + if "[name]" in prompt: + # if [name] is in the prompt, replace it with the embedding token string + prompt = prompt.replace("[name]", self.embedding.get_embedding_string()) + if self.embedding.get_embedding_string() not in prompt: + # otherwise add it to the beginning of the prompt + prompt = self.embedding.get_embedding_string() + " " + prompt + + conditioned_prompts.append(prompt) + + # # get embedding ids + # embedding_ids_list = [self.sd.tokenizer( + # text, + # padding="max_length", + # truncation=True, + # max_length=self.sd.tokenizer.model_max_length, + # return_tensors="pt", + # ).input_ids[0] for text in conditioned_prompts] + + # hidden_states = [] + # for embedding_ids, img in zip(embedding_ids_list, imgs): + # hidden_state = { + # "input_ids": embedding_ids, + # "pixel_values": img + # } + # hidden_states.append(hidden_state) + + dtype = get_torch_dtype(self.train_config.dtype) + imgs = imgs.to(self.device_torch, dtype=dtype) + latents = self.sd.encode_images(imgs) + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + timesteps = timesteps.long() + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=imgs.shape[2], + pixel_width=imgs.shape[3], + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset + ).to(self.device_torch, dtype=dtype) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # remove grads for these + noisy_latents.requires_grad = False + noise.requires_grad = False + + flush() + + self.optimizer.zero_grad() + + # text encoding + embedding_list = [] + # embed the prompts + for prompt in conditioned_prompts: + embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + conditional_embeds = concat_prompt_embeds(embedding_list) + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # apply min_snr_gamma weighting + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # backpropagate loss (this also frees the graph/activation memory) + loss.backward() + 
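# note: flush() below only empties the CUDA cache and runs gc; parameter .grad tensors are unaffected, so optimizer.step() still sees the gradients +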
flush() + + # apply gradients + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + # Let's make sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool) + index_no_updates[ + min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False + with torch.no_grad(): + self.sd.text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = self.orig_embeds_params[index_no_updates] + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + + return loss_dict + # end hook_train_loop diff --git a/extensions_built_in/textual_inversion_trainer/__init__.py b/extensions_built_in/textual_inversion_trainer/__init__.py new file mode 100644 index 00000000..7167178f --- /dev/null +++ b/extensions_built_in/textual_inversion_trainer/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# We make a subclass of Extension +class OffsetSliderTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "textual_inversion_trainer" + + # name is the name of the extension for printing + name = "Textual Inversion Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .TextualInversionTrainer import TextualInversionTrainer + return TextualInversionTrainer + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + OffsetSliderTrainer +] diff --git a/extensions_built_in/textual_inversion_trainer/config/train.example.yaml b/extensions_built_in/textual_inversion_trainer/config/train.example.yaml new file mode 100644 index 00000000..8b0f4734 --- /dev/null +++ b/extensions_built_in/textual_inversion_trainer/config/train.example.yaml @@ -0,0 +1,107 @@ +--- +job: extension +config: + name: example_name + process: + - type: 'image_reference_slider_trainer' + training_folder: "/mnt/Train/out/LoRA" + device: cuda:0 + # for tensorboard logging + log_dir: "/home/jaret/Dev/.tensorboard" + network: + type: "lora" + linear: 8 + linear_alpha: 8 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 5000 + lr: 1e-4 + train_unet: true + gradient_checkpointing: true + train_text_encoder: true + optimizer: "adamw" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 1 + dtype: bf16 + xformers: true + skip_first_sample: true + noise_offset: 0.0 + model: + name_or_path: "/path/to/model.safetensors" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + save: + dtype: float16 # precision to save + save_every: 1000 # save every this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man 
smiling at the camera, in a tank top --m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m 1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + + slider: + datasets: + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 2.0 + target_class: "" # only used as the default if caption .txt files are not present + size: 512 + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 4.0 + target_class: "" # only used as the default if caption .txt files are not present + size: 512 + + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend including trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All metadata above is arbitrary; it can be whatever you want. 
\ No newline at end of file diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 0606b9de..ae269e52 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -5,6 +5,8 @@ from typing import Union from torch.utils.data import DataLoader +from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.embedding import Embedding from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer @@ -20,7 +22,7 @@ import torch from tqdm import tqdm from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ - GenerateImageConfig + GenerateImageConfig, EmbeddingConfig, DatasetConfig def flush(): @@ -30,6 +32,7 @@ def flush(): class BaseSDTrainProcess(BaseTrainProcess): sd: StableDiffusion + embedding: Union[Embedding, None] = None def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None): super().__init__(process_id, job, config) @@ -59,6 +62,16 @@ class BaseSDTrainProcess(BaseTrainProcess): self.lr_scheduler = None self.data_loader: Union[DataLoader, None] = None + raw_datasets = self.get_conf('datasets', None) + self.datasets = None + if raw_datasets is not None and len(raw_datasets) > 0: + self.datasets = [DatasetConfig(**d) for d in raw_datasets] + + self.embed_config = None + embedding_raw = self.get_conf('embedding', None) + if embedding_raw is not None: + self.embed_config = EmbeddingConfig(**embedding_raw) + self.sd = StableDiffusion( device=self.device, model_config=self.model_config, @@ -68,6 +81,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network = None + self.embedding = None def sample(self, step=None, is_first=False): sample_folder = os.path.join(self.save_root, 'samples') @@ -89,8 +103,26 @@ class BaseSDTrainProcess(BaseTrainProcess): output_path = os.path.join(sample_folder, filename) + prompt = sample_config.prompts[i] + + # add embedding if there is one + if self.embedding is not None: + # replace our name with the embedding + if self.embed_config.trigger in prompt: + # if the trigger is a part of the prompt, replace it with the token ids + prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string()) + if self.name in prompt: + # if the name is in the prompt, replace it with the trigger + prompt = prompt.replace(self.name, self.embedding.get_embedding_string()) + if "[name]" in prompt: + # in [name] in prompt, replace it with the trigger + prompt = prompt.replace("[name]", self.embedding.get_embedding_string()) + if self.embedding.get_embedding_string() not in prompt: + # add it to the beginning of the prompt + prompt = self.embedding.get_embedding_string() + " " + prompt + gen_img_config_list.append(GenerateImageConfig( - prompt=sample_config.prompts[i], # it will autoparse the prompt + prompt=prompt, # it will autoparse the prompt width=sample_config.width, height=sample_config.height, negative_prompt=sample_config.neg, @@ -175,6 +207,8 @@ class BaseSDTrainProcess(BaseTrainProcess): metadata=save_meta ) self.network.multiplier = prev_multiplier + elif self.embedding is not None: + self.embedding.save(file_path) else: self.sd.save( file_path, @@ -197,6 +231,9 @@ class BaseSDTrainProcess(BaseTrainProcess): def hook_before_train_loop(self): pass + def before_dataset_load(self): + pass + def hook_train_loop(self, batch=None): # return loss return 0.0 @@ -208,6 +245,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # 
pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors pattern = f"{self.job.name}*.safetensors" files = glob.glob(os.path.join(self.save_root, pattern)) + if len(files) > 0: + latest_file = max(files, key=os.path.getctime) + # try pt + pattern = f"{self.job.name}*.pt" + files = glob.glob(os.path.join(self.save_root, pattern)) if len(files) > 0: latest_file = max(files, key=os.path.getctime) return latest_file @@ -230,11 +272,21 @@ class BaseSDTrainProcess(BaseTrainProcess): def run(self): # run base process run BaseTrainProcess.run(self) + ### HOOk ### + self.before_dataset_load() + # load datasets if passed in the root process + if self.datasets is not None: + self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size) + ### HOOK ### self.hook_before_model_load() # run base sd process run self.sd.load_model() + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + dtype = get_torch_dtype(self.train_config.dtype) # model is loaded from BaseSDProcess @@ -303,7 +355,18 @@ class BaseSDTrainProcess(BaseTrainProcess): self.print(f"Loading from {latest_save_path}") self.load_weights(latest_save_path) self.network.multiplier = 1.0 + elif self.embed_config is not None: + self.embedding = Embedding( + sd=self.sd, + embed_config=self.embed_config + ) + latest_save_path = self.get_latest_save_path() + # load last saved weights + if latest_save_path is not None: + self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) + # set trainable params + params = self.embedding.get_trainable_params() else: params = [] diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f95196a8..d79eeb53 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,6 +1,6 @@ import os import time -from typing import List, Optional +from typing import List, Optional, Literal import random @@ -50,6 +50,13 @@ class NetworkConfig: self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) +class EmbeddingConfig: + def __init__(self, **kwargs): + self.trigger = kwargs.get('trigger', 'custom_embedding') + self.tokens = kwargs.get('tokens', 4) + self.init_words = kwargs.get('init_phrase', '*') + self.save_format = kwargs.get('save_format', 'safetensors') + class TrainConfig: def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') @@ -146,6 +153,20 @@ class SliderConfig: self.targets.append(target) +class DatasetConfig: + caption_type: Literal["txt", "caption"] = 'txt' + + def __init__(self, **kwargs): + self.type = kwargs.get('type', 'image') # sd, slider, reference + self.folder_path: str = kwargs.get('folder_path', None) + self.default_caption: str = kwargs.get('default_caption', None) + self.caption_type: str = kwargs.get('caption_type', None) + self.random_scale: bool = kwargs.get('random_scale', False) + self.random_crop: bool = kwargs.get('random_crop', False) + self.resolution: int = kwargs.get('resolution', 512) + self.scale: float = kwargs.get('scale', 1.0) + + class GenerateImageConfig: def __init__( self, diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index e4ddd51d..e399d1be 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -1,23 +1,33 @@ import os import random +from typing import List import cv2 import numpy as np from PIL import Image from PIL.ImageOps import exif_transpose from torchvision import transforms -from torch.utils.data import Dataset +from torch.utils.data 
import Dataset, DataLoader, ConcatDataset from tqdm import tqdm import albumentations as A +from toolkit.config_modules import DatasetConfig +from toolkit.dataloader_mixins import CaptionMixin -class ImageDataset(Dataset): + +class ImageDataset(Dataset, CaptionMixin): def __init__(self, config): self.config = config self.name = self.get_config('name', 'dataset') self.path = self.get_config('path', required=True) self.scale = self.get_config('scale', 1) self.random_scale = self.get_config('random_scale', False) + self.include_prompt = self.get_config('include_prompt', False) + self.default_prompt = self.get_config('default_prompt', '') + if self.include_prompt: + self.caption_type = self.get_config('caption_type', 'txt') + else: + self.caption_type = None # we always random crop if random scale is enabled self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False) @@ -81,7 +91,11 @@ class ImageDataset(Dataset): img = self.transform(img) - return img + if self.include_prompt: + prompt = self.get_caption_item(index) + return img, prompt + else: + return img class Augments: @@ -268,3 +282,101 @@ class PairedImageDataset(Dataset): img = self.transform(img) return img, prompt, (self.neg_weight, self.pos_weight) + + +class AiToolkitDataset(Dataset, CaptionMixin): + def __init__(self, dataset_config: 'DatasetConfig'): + self.dataset_config = dataset_config + self.folder_path = dataset_config.folder_path + self.caption_type = dataset_config.caption_type + self.default_caption = dataset_config.default_caption + self.random_scale = dataset_config.random_scale + self.scale = dataset_config.scale + # we always random crop if random scale is enabled + self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop + self.resolution = dataset_config.resolution + + # get the file list + self.file_list = [ + os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if + file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')) + ] + + # this might take a while + print(f" - Preprocessing image dimensions") + new_file_list = [] + bad_count = 0 + for file in tqdm(self.file_list): + img = Image.open(file) + if int(min(img.size) * self.scale) >= self.resolution: + new_file_list.append(file) + else: + bad_count += 1 + + print(f" - Found {len(self.file_list)} images") + print(f" - Found {bad_count} images that are too small") + assert len(self.file_list) > 0, f"no images found in {self.folder_path}" + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] + ]) + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, index): + img_path = self.file_list[index] + img = exif_transpose(Image.open(img_path)).convert('RGB') + + # Downscale the source image first + img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC) + min_img_size = min(img.size) + + if self.random_crop: + if self.random_scale and min_img_size > self.resolution: + if min_img_size < self.resolution: + print( + f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}") + scale_size = self.resolution + else: + scale_size = random.randint(self.resolution, int(min_img_size)) + img = img.resize((scale_size, scale_size), Image.BICUBIC) + img = transforms.RandomCrop(self.resolution)(img) + else: + img = transforms.CenterCrop(min_img_size)(img) + img = img.resize((self.resolution, 
self.resolution), Image.BICUBIC) + + img = self.transform(img) + + if self.caption_type is not None: + prompt = self.get_caption_item(index) + return img, prompt + else: + return img + + +def get_dataloader_from_datasets(dataset_options, batch_size=1): + if dataset_options is None or len(dataset_options) == 0: + return None + + datasets = [] + for dataset_option in dataset_options: + if isinstance(dataset_option, DatasetConfig): + config = dataset_option + else: + config = DatasetConfig(**dataset_option) + if config.type == 'image': + dataset = AiToolkitDataset(config) + datasets.append(dataset) + else: + raise ValueError(f"invalid dataset type: {config.type}") + + concatenated_dataset = ConcatDataset(datasets) + data_loader = DataLoader( + concatenated_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=2 + ) + return data_loader diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py new file mode 100644 index 00000000..913075fa --- /dev/null +++ b/toolkit/dataloader_mixins.py @@ -0,0 +1,43 @@ +import os + + +class CaptionMixin: + def get_caption_item(self, index): + if not hasattr(self, 'caption_type'): + raise Exception('caption_type not found on class instance') + if not hasattr(self, 'file_list'): + raise Exception('file_list not found on class instance') + img_path_or_tuple = self.file_list[index] + if isinstance(img_path_or_tuple, tuple): + # check if either has a prompt file + path_no_ext = os.path.splitext(img_path_or_tuple[0])[0] + prompt_path = path_no_ext + '.txt' + if not os.path.exists(prompt_path): + path_no_ext = os.path.splitext(img_path_or_tuple[1])[0] + prompt_path = path_no_ext + '.txt' + else: + img_path = img_path_or_tuple + # see if prompt file exists + path_no_ext = os.path.splitext(img_path)[0] + prompt_path = path_no_ext + '.txt' + + if os.path.exists(prompt_path): + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + # remove any newlines + prompt = prompt.replace('\n', ', ') + # remove new lines for all operating systems + prompt = prompt.replace('\r', ', ') + prompt_split = prompt.split(',') + # remove empty strings + prompt_split = [p.strip() for p in prompt_split if p.strip()] + # join back together + prompt = ', '.join(prompt_split) + else: + prompt = '' + # get default_prompt if it exists on the class instance + if hasattr(self, 'default_prompt'): + prompt = self.default_prompt + if hasattr(self, 'default_caption'): + prompt = self.default_caption + return prompt diff --git a/toolkit/embedding.py b/toolkit/embedding.py new file mode 100644 index 00000000..40e01250 --- /dev/null +++ b/toolkit/embedding.py @@ -0,0 +1,185 @@ +import json +import os +from collections import OrderedDict + +import safetensors +import torch +from typing import TYPE_CHECKING + +from safetensors.torch import save_file + +from toolkit.metadata import get_meta_for_safetensors + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.config_modules import EmbeddingConfig + + +# this is a frankenstein mix of automatic1111 and my own code + +class Embedding: + def __init__( + self, + sd: 'StableDiffusion', + embed_config: 'EmbeddingConfig' + ): + self.name = embed_config.trigger + self.sd = sd + self.embed_config = embed_config + # setup our embedding + # Add the placeholder token in tokenizer + placeholder_tokens = [self.embed_config.trigger] + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, self.embed_config.tokens): + 
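# e.g. with the defaults (trigger "custom_embedding", tokens=4) this yields custom_embedding, custom_embedding_1, custom_embedding_2, custom_embedding_3 +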
additional_tokens.append(f"{self.embed_config.trigger}_{i}") + placeholder_tokens += additional_tokens + + num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != self.embed_config.tokens: + raise ValueError( + f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + # Convert the initializer_token, placeholder_token to ids + init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False) + # if length of token ids is more than number of orm embedding tokens fill with * + if len(init_token_ids) > self.embed_config.tokens: + init_token_ids = init_token_ids[:self.embed_config.tokens] + elif len(init_token_ids) < self.embed_config.tokens: + pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False) + init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids)) + + self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + # todo SDXL has 2 text encoders, need to do both for all of this + self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids): + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # this doesnt seem to be used again + self.token_embeds = token_embeds + + # replace "[name] with this. This triggers it in the text encoder + self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids)) + + # returns the string to have in the prompt to trigger the embedding + def get_embedding_string(self): + return self.embedding_tokens + + def get_trainable_params(self): + # todo only get this one as we could have more than one + return self.sd.text_encoder.get_input_embeddings().parameters() + + # make setter and getter for vec + @property + def vec(self): + # should we get params instead + # create vector from token embeds + token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data + # stack the tokens along batch axis adding that axis + new_vector = torch.stack( + [token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids], + dim=0 + ) + return new_vector + + @vec.setter + def vec(self, new_vector): + # shape is (1, 768) for SD 1.5 for 1 token + token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data + for i in range(new_vector.shape[0]): + # apply the weights to the placeholder tokens while preserving gradient + token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone() + x = 1 + + def save(self, filename): + # todo check to see how to get the vector out of the embedding + + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": 0, + # todo get these + "sd_checkpoint": None, + "sd_checkpoint_name": None, + "notes": None, + } + if filename.endswith('.pt'): + torch.save(embedding_data, filename) + elif filename.endswith('.bin'): + torch.save(embedding_data, filename) + elif filename.endswith('.safetensors'): + # save the embedding as a safetensors file + state_dict = {"emb_params": self.vec} + # 
add all embedding data (except string_to_param), to metadata + metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"}) + metadata["string_to_param"] = {"*": "emb_params"} + save_meta = get_meta_for_safetensors(metadata, name=self.name) + save_file(state_dict, filename, metadata=save_meta) + + def load_embedding_from_file(self, file_path, device): + # full path + path = os.path.realpath(file_path) + filename = os.path.basename(path) + name, ext = os.path.splitext(filename) + ext = ext.upper() + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + + if ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + # rebuild the embedding from the safetensors file if it has it + tensors = {} + with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + tensors[k] = f.get_tensor(k) + # data = safetensors.torch.load_file(path, device="cpu") + if metadata and 'string_to_param' in metadata and 'emb_params' in tensors: + # our format + def try_json(v): + try: + return json.loads(v) + except: + return v + + data = {k: try_json(v) for k, v in metadata.items()} + data['string_to_param'] = {'*': tensors['emb_params']} + else: + # old format + data = tensors + else: + return + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, + '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception( + f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + self.vec = emb.detach().to(device, dtype=torch.float32) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 7db5e05b..4e7c1141 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -520,6 +520,11 @@ class StableDiffusion: noise_pred_text - noise_pred_uncond ) + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + return noise_pred # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
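Note: the bundled config/train.example.yaml above still carries the slider-trainer process type and a `slider:` section rather than the new `embedding:`/`datasets:` options added by this patch. Below is a minimal sketch of a config this trainer could consume; it assumes the process `type` resolves to the extension uid (`textual_inversion_trainer`), uses only keys read by `EmbeddingConfig`, `DatasetConfig`, and the trainer, and all paths, names, and hyperparameter values are illustrative.

---
job: extension
config:
  name: my_embedding
  process:
    - type: 'textual_inversion_trainer'  # assumed to match the extension uid
      training_folder: "/path/to/output"
      device: cuda:0
      embedding:
        trigger: my_concept        # word in prompts that gets swapped for the learned tokens
        tokens: 4                  # number of embedding vectors to train
        init_phrase: "a photo"     # used to initialize the new token embeddings
        save_format: safetensors
      datasets:
        - type: image
          folder_path: "/path/to/train/images"
          caption_type: txt        # read captions from sidecar .txt files
          default_caption: "a photo of my_concept"
          resolution: 512
      train:
        noise_scheduler: "ddpm"
        steps: 3000
        lr: 5e-4
        batch_size: 1
        dtype: bf16
        max_denoising_steps: 1000
      model:
        name_or_path: "/path/to/model.safetensors"
      save:
        dtype: float16
        save_every: 500
      sample:
        sample_every: 200
        width: 512
        height: 512
        guidance_scale: 7
        sample_steps: 20
        prompts:
          - "photo of my_concept"  # the trigger (or [name]) is replaced with the embedding tokens at sample time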