From 2e6c55c7202a45fae0598a36a2ae1fd4d13c0b9a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 22 Aug 2023 21:02:38 -0600 Subject: [PATCH 1/3] WIP creating textual inversion training script --- .../TextualInversionTrainer.py | 179 +++++++++++++++++ .../textual_inversion_trainer/__init__.py | 25 +++ .../config/train.example.yaml | 107 ++++++++++ jobs/process/BaseSDTrainProcess.py | 67 ++++++- toolkit/config_modules.py | 23 ++- toolkit/data_loader.py | 118 ++++++++++- toolkit/dataloader_mixins.py | 43 ++++ toolkit/embedding.py | 185 ++++++++++++++++++ toolkit/stable_diffusion_model.py | 5 + 9 files changed, 746 insertions(+), 6 deletions(-) create mode 100644 extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py create mode 100644 extensions_built_in/textual_inversion_trainer/__init__.py create mode 100644 extensions_built_in/textual_inversion_trainer/config/train.example.yaml create mode 100644 toolkit/dataloader_mixins.py create mode 100644 toolkit/embedding.py diff --git a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py new file mode 100644 index 00000000..31da8402 --- /dev/null +++ b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py @@ -0,0 +1,179 @@ +import copy +import random +from collections import OrderedDict +import os +from contextlib import nullcontext +from typing import Optional, Union, List +from torch.utils.data import ConcatDataset, DataLoader + +from toolkit.config_modules import ReferenceDatasetConfig +from toolkit.data_loader import PairedImageDataset, ImageDataset +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds +from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset +import gc +from toolkit import train_tools +import torch +from jobs.process import BaseSDTrainProcess +import random +from toolkit.basic import value_map + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class TextualInversionTrainer(BaseSDTrainProcess): + sd: StableDiffusion + data_loader: DataLoader = None + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + pass + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + + # keep original embeddings as reference + self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone() + # set text encoder to train. Not sure if this is necessary but diffusers example did it + self.sd.text_encoder.train() + pass + + def hook_train_loop(self, batch): + with torch.no_grad(): + imgs, prompts = batch + + # very loosely based on this. 
very loosely + # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py + + conditioned_prompts = [] + + for prompt in prompts: + # replace our name with the embedding + if self.embed_config.trigger in prompt: + # if the trigger is a part of the prompt, replace it with the token ids + prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string()) + if self.name in prompt: + # if the name is in the prompt, replace it with the trigger + prompt = prompt.replace(self.name, self.embedding.get_embedding_string()) + if "[name]" in prompt: + # in [name] in prompt, replace it with the trigger + prompt = prompt.replace("[name]", self.embedding.get_embedding_string()) + if self.embedding.get_embedding_string() not in prompt: + # add it to the beginning of the prompt + prompt = self.embedding.get_embedding_string() + " " + prompt + + conditioned_prompts.append(prompt) + + # # get embedding ids + # embedding_ids_list = [self.sd.tokenizer( + # text, + # padding="max_length", + # truncation=True, + # max_length=self.sd.tokenizer.model_max_length, + # return_tensors="pt", + # ).input_ids[0] for text in conditioned_prompts] + + # hidden_states = [] + # for embedding_ids, img in zip(embedding_ids_list, imgs): + # hidden_state = { + # "input_ids": embedding_ids, + # "pixel_values": img + # } + # hidden_states.append(hidden_state) + + dtype = get_torch_dtype(self.train_config.dtype) + imgs = imgs.to(self.device_torch, dtype=dtype) + latents = self.sd.encode_images(imgs) + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + timesteps = timesteps.long() + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=imgs.shape[2], + pixel_width=imgs.shape[3], + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset + ).to(self.device_torch, dtype=dtype) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # remove grads for these + noisy_latents.requires_grad = False + noise.requires_grad = False + + flush() + + self.optimizer.zero_grad() + noisy_latents.requires_grad = False + + # text encoding + embedding_list = [] + # embed the prompts + for prompt in conditioned_prompts: + embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + conditional_embeds = concat_prompt_embeds(embedding_list) + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # back propagate loss to free ram + loss.backward() + 
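Note on the `min_snr_gamma` branch a few lines above: the patch delegates the actual weighting to `toolkit.train_tools.apply_snr_weight`, whose body is not shown in this diff. For reference, here is a minimal sketch of the standard Min-SNR-gamma weighting for epsilon-prediction; the function name `min_snr_weight` and the `gamma` default are illustrative, and the toolkit helper may handle details (such as v-prediction targets) differently.

```python
import torch

def min_snr_weight(timesteps, alphas_cumprod, gamma=5.0):
    """Sketch of the usual Min-SNR-gamma weight for an epsilon-prediction MSE loss.

    Not the toolkit implementation; just the standard formulation for context.
    """
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)        # signal-to-noise ratio per timestep
    return torch.clamp(snr, max=gamma) / snr   # down-weights low-noise (high-SNR) timesteps

# usage against a per-sample loss of shape [batch]:
# loss = loss * min_snr_weight(timesteps, noise_scheduler.alphas_cumprod.to(loss.device))
```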
flush() + + # apply gradients + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + # Let's make sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool) + index_no_updates[ + min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False + with torch.no_grad(): + self.sd.text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = self.orig_embeds_params[index_no_updates] + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + + return loss_dict + # end hook_train_loop diff --git a/extensions_built_in/textual_inversion_trainer/__init__.py b/extensions_built_in/textual_inversion_trainer/__init__.py new file mode 100644 index 00000000..7167178f --- /dev/null +++ b/extensions_built_in/textual_inversion_trainer/__init__.py @@ -0,0 +1,25 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# We make a subclass of Extension +class OffsetSliderTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "textual_inversion_trainer" + + # name is the name of the extension for printing + name = "Textual Inversion Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .TextualInversionTrainer import TextualInversionTrainer + return TextualInversionTrainer + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + OffsetSliderTrainer +] diff --git a/extensions_built_in/textual_inversion_trainer/config/train.example.yaml b/extensions_built_in/textual_inversion_trainer/config/train.example.yaml new file mode 100644 index 00000000..8b0f4734 --- /dev/null +++ b/extensions_built_in/textual_inversion_trainer/config/train.example.yaml @@ -0,0 +1,107 @@ +--- +job: extension +config: + name: example_name + process: + - type: 'image_reference_slider_trainer' + training_folder: "/mnt/Train/out/LoRA" + device: cuda:0 + # for tensorboard logging + log_dir: "/home/jaret/Dev/.tensorboard" + network: + type: "lora" + linear: 8 + linear_alpha: 8 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 5000 + lr: 1e-4 + train_unet: true + gradient_checkpointing: true + train_text_encoder: true + optimizer: "adamw" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 1 + dtype: bf16 + xformers: true + skip_first_sample: true + noise_offset: 0.0 + model: + name_or_path: "/path/to/model.safetensors" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + save: + dtype: float16 # precision to save + save_every: 1000 # save every this many steps + max_step_saves_to_keep: 2 # only affects step counts + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of a woman with red hair taking a selfie --m -3" + - "photo of a woman with red hair taking a selfie --m -1" + - "photo of a woman with red hair taking a selfie --m 1" + - "photo of a woman with red hair taking a selfie --m 3" + - "close up photo of a man smiling at the camera, in a tank top --m -3" + - "close up photo of a man 
smiling at the camera, in a tank top--m -1" + - "close up photo of a man smiling at the camera, in a tank top --m 1" + - "close up photo of a man smiling at the camera, in a tank top --m 3" + - "photo of a blonde woman smiling, barista --m -3" + - "photo of a blonde woman smiling, barista --m -1" + - "photo of a blonde woman smiling, barista --m 1" + - "photo of a blonde woman smiling, barista --m 3" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m -1" + - "photo of a Christina Hendricks --m 1" + - "photo of a Christina Hendricks --m 3" + - "photo of a Christina Ricci --m -3" + - "photo of a Christina Ricci --m -1" + - "photo of a Christina Ricci --m 1" + - "photo of a Christina Ricci --m 3" + neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + + slider: + datasets: + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 2.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + - pair_folder: "/path/to/folder/side/by/side/images" + network_weight: 4.0 + target_class: "" # only used as default if caption txt are not present + size: 512 + + +# you can put any information you want here, and it will be saved in the model +# the below is an example. I recommend doing trigger words at a minimum +# in the metadata. The software will include this plus some other information +meta: + name: "[name]" # [name] gets replaced with the name above + description: A short description of your model + trigger_words: + - put + - trigger + - words + - here + version: '0.1' + creator: + name: Your Name + email: your@email.com + website: https://yourwebsite.com + any: All meta data above is arbitrary, it can be whatever you want. 
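The example config above is consumed as plain YAML before being mapped onto the config classes in `toolkit/config_modules.py`. A small, hypothetical sanity-check snippet (the file path is a placeholder) that mirrors how the `job`/`config`/`process` structure is read:

```python
import yaml  # PyYAML

# Hypothetical check of the structure shown above; adjust the path as needed.
with open("config/train.example.yaml", "r") as f:
    raw = yaml.safe_load(f)

print(raw["job"])  # "extension"
for process in raw["config"]["process"]:
    # each entry becomes one training process; its keys map onto the
    # config classes defined in toolkit/config_modules.py
    print(process["type"], process.get("training_folder"))
```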
\ No newline at end of file diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 0606b9de..ae269e52 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -5,6 +5,8 @@ from typing import Union from torch.utils.data import DataLoader +from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.embedding import Embedding from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer @@ -20,7 +22,7 @@ import torch from tqdm import tqdm from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ - GenerateImageConfig + GenerateImageConfig, EmbeddingConfig, DatasetConfig def flush(): @@ -30,6 +32,7 @@ def flush(): class BaseSDTrainProcess(BaseTrainProcess): sd: StableDiffusion + embedding: Union[Embedding, None] = None def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None): super().__init__(process_id, job, config) @@ -59,6 +62,16 @@ class BaseSDTrainProcess(BaseTrainProcess): self.lr_scheduler = None self.data_loader: Union[DataLoader, None] = None + raw_datasets = self.get_conf('datasets', None) + self.datasets = None + if raw_datasets is not None and len(raw_datasets) > 0: + self.datasets = [DatasetConfig(**d) for d in raw_datasets] + + self.embed_config = None + embedding_raw = self.get_conf('embedding', None) + if embedding_raw is not None: + self.embed_config = EmbeddingConfig(**embedding_raw) + self.sd = StableDiffusion( device=self.device, model_config=self.model_config, @@ -68,6 +81,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network = None + self.embedding = None def sample(self, step=None, is_first=False): sample_folder = os.path.join(self.save_root, 'samples') @@ -89,8 +103,26 @@ class BaseSDTrainProcess(BaseTrainProcess): output_path = os.path.join(sample_folder, filename) + prompt = sample_config.prompts[i] + + # add embedding if there is one + if self.embedding is not None: + # replace our name with the embedding + if self.embed_config.trigger in prompt: + # if the trigger is a part of the prompt, replace it with the token ids + prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string()) + if self.name in prompt: + # if the name is in the prompt, replace it with the trigger + prompt = prompt.replace(self.name, self.embedding.get_embedding_string()) + if "[name]" in prompt: + # in [name] in prompt, replace it with the trigger + prompt = prompt.replace("[name]", self.embedding.get_embedding_string()) + if self.embedding.get_embedding_string() not in prompt: + # add it to the beginning of the prompt + prompt = self.embedding.get_embedding_string() + " " + prompt + gen_img_config_list.append(GenerateImageConfig( - prompt=sample_config.prompts[i], # it will autoparse the prompt + prompt=prompt, # it will autoparse the prompt width=sample_config.width, height=sample_config.height, negative_prompt=sample_config.neg, @@ -175,6 +207,8 @@ class BaseSDTrainProcess(BaseTrainProcess): metadata=save_meta ) self.network.multiplier = prev_multiplier + elif self.embedding is not None: + self.embedding.save(file_path) else: self.sd.save( file_path, @@ -197,6 +231,9 @@ class BaseSDTrainProcess(BaseTrainProcess): def hook_before_train_loop(self): pass + def before_dataset_load(self): + pass + def hook_train_loop(self, batch=None): # return loss return 0.0 @@ -208,6 +245,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # 
pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors pattern = f"{self.job.name}*.safetensors" files = glob.glob(os.path.join(self.save_root, pattern)) + if len(files) > 0: + latest_file = max(files, key=os.path.getctime) + # try pt + pattern = f"{self.job.name}*.pt" + files = glob.glob(os.path.join(self.save_root, pattern)) if len(files) > 0: latest_file = max(files, key=os.path.getctime) return latest_file @@ -230,11 +272,21 @@ class BaseSDTrainProcess(BaseTrainProcess): def run(self): # run base process run BaseTrainProcess.run(self) + ### HOOk ### + self.before_dataset_load() + # load datasets if passed in the root process + if self.datasets is not None: + self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size) + ### HOOK ### self.hook_before_model_load() # run base sd process run self.sd.load_model() + if self.train_config.gradient_checkpointing: + # may get disabled elsewhere + self.sd.unet.enable_gradient_checkpointing() + dtype = get_torch_dtype(self.train_config.dtype) # model is loaded from BaseSDProcess @@ -303,7 +355,18 @@ class BaseSDTrainProcess(BaseTrainProcess): self.print(f"Loading from {latest_save_path}") self.load_weights(latest_save_path) self.network.multiplier = 1.0 + elif self.embed_config is not None: + self.embedding = Embedding( + sd=self.sd, + embed_config=self.embed_config + ) + latest_save_path = self.get_latest_save_path() + # load last saved weights + if latest_save_path is not None: + self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) + # set trainable params + params = self.embedding.get_trainable_params() else: params = [] diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f95196a8..d79eeb53 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,6 +1,6 @@ import os import time -from typing import List, Optional +from typing import List, Optional, Literal import random @@ -50,6 +50,13 @@ class NetworkConfig: self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) +class EmbeddingConfig: + def __init__(self, **kwargs): + self.trigger = kwargs.get('trigger', 'custom_embedding') + self.tokens = kwargs.get('tokens', 4) + self.init_words = kwargs.get('init_phrase', '*') + self.save_format = kwargs.get('save_format', 'safetensors') + class TrainConfig: def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') @@ -146,6 +153,20 @@ class SliderConfig: self.targets.append(target) +class DatasetConfig: + caption_type: Literal["txt", "caption"] = 'txt' + + def __init__(self, **kwargs): + self.type = kwargs.get('type', 'image') # sd, slider, reference + self.folder_path: str = kwargs.get('folder_path', None) + self.default_caption: str = kwargs.get('default_caption', None) + self.caption_type: str = kwargs.get('caption_type', None) + self.random_scale: bool = kwargs.get('random_scale', False) + self.random_crop: bool = kwargs.get('random_crop', False) + self.resolution: int = kwargs.get('resolution', 512) + self.scale: float = kwargs.get('scale', 1.0) + + class GenerateImageConfig: def __init__( self, diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index e4ddd51d..e399d1be 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -1,23 +1,33 @@ import os import random +from typing import List import cv2 import numpy as np from PIL import Image from PIL.ImageOps import exif_transpose from torchvision import transforms -from torch.utils.data import Dataset +from torch.utils.data 
import Dataset, DataLoader, ConcatDataset from tqdm import tqdm import albumentations as A +from toolkit.config_modules import DatasetConfig +from toolkit.dataloader_mixins import CaptionMixin -class ImageDataset(Dataset): + +class ImageDataset(Dataset, CaptionMixin): def __init__(self, config): self.config = config self.name = self.get_config('name', 'dataset') self.path = self.get_config('path', required=True) self.scale = self.get_config('scale', 1) self.random_scale = self.get_config('random_scale', False) + self.include_prompt = self.get_config('include_prompt', False) + self.default_prompt = self.get_config('default_prompt', '') + if self.include_prompt: + self.caption_type = self.get_config('caption_type', 'txt') + else: + self.caption_type = None # we always random crop if random scale is enabled self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False) @@ -81,7 +91,11 @@ class ImageDataset(Dataset): img = self.transform(img) - return img + if self.include_prompt: + prompt = self.get_caption_item(index) + return img, prompt + else: + return img class Augments: @@ -268,3 +282,101 @@ class PairedImageDataset(Dataset): img = self.transform(img) return img, prompt, (self.neg_weight, self.pos_weight) + + +class AiToolkitDataset(Dataset, CaptionMixin): + def __init__(self, dataset_config: 'DatasetConfig'): + self.dataset_config = dataset_config + self.folder_path = dataset_config.folder_path + self.caption_type = dataset_config.caption_type + self.default_caption = dataset_config.default_caption + self.random_scale = dataset_config.random_scale + self.scale = dataset_config.scale + # we always random crop if random scale is enabled + self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop + self.resolution = dataset_config.resolution + + # get the file list + self.file_list = [ + os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if + file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')) + ] + + # this might take a while + print(f" - Preprocessing image dimensions") + new_file_list = [] + bad_count = 0 + for file in tqdm(self.file_list): + img = Image.open(file) + if int(min(img.size) * self.scale) >= self.resolution: + new_file_list.append(file) + else: + bad_count += 1 + + print(f" - Found {len(self.file_list)} images") + print(f" - Found {bad_count} images that are too small") + assert len(self.file_list) > 0, f"no images found in {self.folder_path}" + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] + ]) + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, index): + img_path = self.file_list[index] + img = exif_transpose(Image.open(img_path)).convert('RGB') + + # Downscale the source image first + img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC) + min_img_size = min(img.size) + + if self.random_crop: + if self.random_scale and min_img_size > self.resolution: + if min_img_size < self.resolution: + print( + f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}") + scale_size = self.resolution + else: + scale_size = random.randint(self.resolution, int(min_img_size)) + img = img.resize((scale_size, scale_size), Image.BICUBIC) + img = transforms.RandomCrop(self.resolution)(img) + else: + img = transforms.CenterCrop(min_img_size)(img) + img = img.resize((self.resolution, 
self.resolution), Image.BICUBIC) + + img = self.transform(img) + + if self.caption_type is not None: + prompt = self.get_caption_item(index) + return img, prompt + else: + return img + + +def get_dataloader_from_datasets(dataset_options, batch_size=1): + if dataset_options is None or len(dataset_options) == 0: + return None + + datasets = [] + for dataset_option in dataset_options: + if isinstance(dataset_option, DatasetConfig): + config = dataset_option + else: + config = DatasetConfig(**dataset_option) + if config.type == 'image': + dataset = AiToolkitDataset(config) + datasets.append(dataset) + else: + raise ValueError(f"invalid dataset type: {config.type}") + + concatenated_dataset = ConcatDataset(datasets) + data_loader = DataLoader( + concatenated_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=2 + ) + return data_loader diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py new file mode 100644 index 00000000..913075fa --- /dev/null +++ b/toolkit/dataloader_mixins.py @@ -0,0 +1,43 @@ +import os + + +class CaptionMixin: + def get_caption_item(self, index): + if not hasattr(self, 'caption_type'): + raise Exception('caption_type not found on class instance') + if not hasattr(self, 'file_list'): + raise Exception('file_list not found on class instance') + img_path_or_tuple = self.file_list[index] + if isinstance(img_path_or_tuple, tuple): + # check if either has a prompt file + path_no_ext = os.path.splitext(img_path_or_tuple[0])[0] + prompt_path = path_no_ext + '.txt' + if not os.path.exists(prompt_path): + path_no_ext = os.path.splitext(img_path_or_tuple[1])[0] + prompt_path = path_no_ext + '.txt' + else: + img_path = img_path_or_tuple + # see if prompt file exists + path_no_ext = os.path.splitext(img_path)[0] + prompt_path = path_no_ext + '.txt' + + if os.path.exists(prompt_path): + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + # remove any newlines + prompt = prompt.replace('\n', ', ') + # remove new lines for all operating systems + prompt = prompt.replace('\r', ', ') + prompt_split = prompt.split(',') + # remove empty strings + prompt_split = [p.strip() for p in prompt_split if p.strip()] + # join back together + prompt = ', '.join(prompt_split) + else: + prompt = '' + # get default_prompt if it exists on the class instance + if hasattr(self, 'default_prompt'): + prompt = self.default_prompt + if hasattr(self, 'default_caption'): + prompt = self.default_caption + return prompt diff --git a/toolkit/embedding.py b/toolkit/embedding.py new file mode 100644 index 00000000..40e01250 --- /dev/null +++ b/toolkit/embedding.py @@ -0,0 +1,185 @@ +import json +import os +from collections import OrderedDict + +import safetensors +import torch +from typing import TYPE_CHECKING + +from safetensors.torch import save_file + +from toolkit.metadata import get_meta_for_safetensors + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + from toolkit.config_modules import EmbeddingConfig + + +# this is a frankenstein mix of automatic1111 and my own code + +class Embedding: + def __init__( + self, + sd: 'StableDiffusion', + embed_config: 'EmbeddingConfig' + ): + self.name = embed_config.trigger + self.sd = sd + self.embed_config = embed_config + # setup our embedding + # Add the placeholder token in tokenizer + placeholder_tokens = [self.embed_config.trigger] + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, self.embed_config.tokens): + 
additional_tokens.append(f"{self.embed_config.trigger}_{i}") + placeholder_tokens += additional_tokens + + num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != self.embed_config.tokens: + raise ValueError( + f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + # Convert the initializer_token, placeholder_token to ids + init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False) + # if length of token ids is more than number of orm embedding tokens fill with * + if len(init_token_ids) > self.embed_config.tokens: + init_token_ids = init_token_ids[:self.embed_config.tokens] + elif len(init_token_ids) < self.embed_config.tokens: + pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False) + init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids)) + + self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + # todo SDXL has 2 text encoders, need to do both for all of this + self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids): + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # this doesnt seem to be used again + self.token_embeds = token_embeds + + # replace "[name] with this. This triggers it in the text encoder + self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids)) + + # returns the string to have in the prompt to trigger the embedding + def get_embedding_string(self): + return self.embedding_tokens + + def get_trainable_params(self): + # todo only get this one as we could have more than one + return self.sd.text_encoder.get_input_embeddings().parameters() + + # make setter and getter for vec + @property + def vec(self): + # should we get params instead + # create vector from token embeds + token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data + # stack the tokens along batch axis adding that axis + new_vector = torch.stack( + [token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids], + dim=0 + ) + return new_vector + + @vec.setter + def vec(self, new_vector): + # shape is (1, 768) for SD 1.5 for 1 token + token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data + for i in range(new_vector.shape[0]): + # apply the weights to the placeholder tokens while preserving gradient + token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone() + x = 1 + + def save(self, filename): + # todo check to see how to get the vector out of the embedding + + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": 0, + # todo get these + "sd_checkpoint": None, + "sd_checkpoint_name": None, + "notes": None, + } + if filename.endswith('.pt'): + torch.save(embedding_data, filename) + elif filename.endswith('.bin'): + torch.save(embedding_data, filename) + elif filename.endswith('.safetensors'): + # save the embedding as a safetensors file + state_dict = {"emb_params": self.vec} + # 
add all embedding data (except string_to_param), to metadata + metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"}) + metadata["string_to_param"] = {"*": "emb_params"} + save_meta = get_meta_for_safetensors(metadata, name=self.name) + save_file(state_dict, filename, metadata=save_meta) + + def load_embedding_from_file(self, file_path, device): + # full path + path = os.path.realpath(file_path) + filename = os.path.basename(path) + name, ext = os.path.splitext(filename) + ext = ext.upper() + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + + if ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + # rebuild the embedding from the safetensors file if it has it + tensors = {} + with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + tensors[k] = f.get_tensor(k) + # data = safetensors.torch.load_file(path, device="cpu") + if metadata and 'string_to_param' in metadata and 'emb_params' in tensors: + # our format + def try_json(v): + try: + return json.loads(v) + except: + return v + + data = {k: try_json(v) for k, v in metadata.items()} + data['string_to_param'] = {'*': tensors['emb_params']} + else: + # old format + data = tensors + else: + return + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, + '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception( + f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + self.vec = emb.detach().to(device, dtype=torch.float32) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 7db5e05b..4e7c1141 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -520,6 +520,11 @@ class StableDiffusion: noise_pred_text - noise_pred_uncond ) + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + return noise_pred # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 From d298240ceccf9270b539655bf17a4a5a60fd6c4b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 23 Aug 2023 13:26:28 -0600 Subject: [PATCH 2/3] Tied in ant tested TI script --- .../TextualInversionTrainer.py | 45 ++++-------------- jobs/process/BaseSDTrainProcess.py | 29 ++++++------ toolkit/config_modules.py | 5 +- toolkit/data_loader.py | 8 ++++ toolkit/embedding.py | 47 ++++++++++++++++--- toolkit/stable_diffusion_model.py | 13 ++++- 6 files changed, 89 insertions(+), 58 deletions(-) diff --git a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py index 31da8402..9eb6e364 100644 --- a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py +++ b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py @@ -52,41 +52,14 @@ class TextualInversionTrainer(BaseSDTrainProcess): # very loosely based on this. very loosely # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py - conditioned_prompts = [] + # make sure the embedding is in the prompts + conditioned_prompts = [self.embedding.inject_embedding_to_prompt( + x, + expand_token=True, + add_if_not_present=True, + ) for x in prompts] - for prompt in prompts: - # replace our name with the embedding - if self.embed_config.trigger in prompt: - # if the trigger is a part of the prompt, replace it with the token ids - prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string()) - if self.name in prompt: - # if the name is in the prompt, replace it with the trigger - prompt = prompt.replace(self.name, self.embedding.get_embedding_string()) - if "[name]" in prompt: - # in [name] in prompt, replace it with the trigger - prompt = prompt.replace("[name]", self.embedding.get_embedding_string()) - if self.embedding.get_embedding_string() not in prompt: - # add it to the beginning of the prompt - prompt = self.embedding.get_embedding_string() + " " + prompt - - conditioned_prompts.append(prompt) - - # # get embedding ids - # embedding_ids_list = [self.sd.tokenizer( - # text, - # padding="max_length", - # truncation=True, - # max_length=self.sd.tokenizer.model_max_length, - # return_tensors="pt", - # ).input_ids[0] for text in conditioned_prompts] - - # hidden_states = [] - # for embedding_ids, img in zip(embedding_ids_list, imgs): - # hidden_state = { - # "input_ids": embedding_ids, - # "pixel_values": img - # } - # hidden_states.append(hidden_state) + batch_size = imgs.shape[0] dtype = get_torch_dtype(self.train_config.dtype) imgs = imgs.to(self.device_torch, dtype=dtype) @@ -100,14 +73,14 @@ class TextualInversionTrainer(BaseSDTrainProcess): self.train_config.max_denoising_steps, device=self.device_torch ) - timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch) timesteps = timesteps.long() # get noise noise = self.sd.get_latent_noise( pixel_height=imgs.shape[2], pixel_width=imgs.shape[3], - batch_size=self.train_config.batch_size, + batch_size=batch_size, 
noise_offset=self.train_config.noise_offset ).to(self.device_torch, dtype=dtype) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index ae269e52..838bd8dc 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -106,20 +106,12 @@ class BaseSDTrainProcess(BaseTrainProcess): prompt = sample_config.prompts[i] # add embedding if there is one + # note: diffusers will automatically expand the trigger to the number of added tokens + # ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here if self.embedding is not None: - # replace our name with the embedding - if self.embed_config.trigger in prompt: - # if the trigger is a part of the prompt, replace it with the token ids - prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string()) - if self.name in prompt: - # if the name is in the prompt, replace it with the trigger - prompt = prompt.replace(self.name, self.embedding.get_embedding_string()) - if "[name]" in prompt: - # in [name] in prompt, replace it with the trigger - prompt = prompt.replace("[name]", self.embedding.get_embedding_string()) - if self.embedding.get_embedding_string() not in prompt: - # add it to the beginning of the prompt - prompt = self.embedding.get_embedding_string() + " " + prompt + prompt = self.embedding.inject_embedding_to_prompt( + prompt, + ) gen_img_config_list.append(GenerateImageConfig( prompt=prompt, # it will autoparse the prompt @@ -208,6 +200,12 @@ class BaseSDTrainProcess(BaseTrainProcess): ) self.network.multiplier = prev_multiplier elif self.embedding is not None: + # set current step + self.embedding.step = self.step_num + # change filename to pt if that is set + if self.embed_config.save_format == "pt": + # replace extension + file_path = os.path.splitext(file_path)[0] + ".pt" self.embedding.save(file_path) else: self.sd.save( @@ -234,7 +232,7 @@ class BaseSDTrainProcess(BaseTrainProcess): def before_dataset_load(self): pass - def hook_train_loop(self, batch=None): + def hook_train_loop(self, batch): # return loss return 0.0 @@ -365,6 +363,9 @@ class BaseSDTrainProcess(BaseTrainProcess): if latest_save_path is not None: self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) + # resume state from embedding + self.step_num = self.embedding.step + # set trainable params params = self.embedding.get_trainable_params() diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index d79eeb53..46e22fda 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -54,9 +54,10 @@ class EmbeddingConfig: def __init__(self, **kwargs): self.trigger = kwargs.get('trigger', 'custom_embedding') self.tokens = kwargs.get('tokens', 4) - self.init_words = kwargs.get('init_phrase', '*') + self.init_words = kwargs.get('init_words', '*') self.save_format = kwargs.get('save_format', 'safetensors') + class TrainConfig: def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') @@ -75,6 +76,7 @@ class TrainConfig: self.optimizer_params = kwargs.get('optimizer_params', {}) self.skip_first_sample = kwargs.get('skip_first_sample', False) self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True) + self.weight_jitter = kwargs.get('weight_jitter', 0.0) class ModelConfig: @@ -165,6 +167,7 @@ class DatasetConfig: self.random_crop: bool = kwargs.get('random_crop', False) self.resolution: int = kwargs.get('resolution', 512) self.scale: float = kwargs.get('scale', 1.0) + 
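`DatasetConfig` is a thin kwargs wrapper, and `BaseSDTrainProcess` builds one instance per `datasets:` entry before handing the list to `get_dataloader_from_datasets`. A rough usage sketch of how those pieces connect (the folder path and caption values below are placeholders):

```python
from toolkit.config_modules import DatasetConfig
from toolkit.data_loader import get_dataloader_from_datasets

# Placeholder dataset definition, shaped like a `datasets:` entry from the YAML config.
raw_datasets = [
    {
        "folder_path": "/path/to/dataset",
        "caption_type": "txt",           # read captions from sidecar .txt files
        "default_caption": "[trigger]",  # fallback when no .txt file exists
        "resolution": 512,
    },
]

# As done in BaseSDTrainProcess: raw dicts -> DatasetConfig -> DataLoader.
datasets = [DatasetConfig(**d) for d in raw_datasets]
data_loader = get_dataloader_from_datasets(datasets, batch_size=4)

for imgs, prompts in data_loader:  # AiToolkitDataset yields (image, caption) when captions are enabled
    print(imgs.shape, prompts[0])
    break
```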
self.buckets: bool = kwargs.get('buckets', False) class GenerateImageConfig: diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index e399d1be..2ccf6890 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -14,6 +14,13 @@ import albumentations as A from toolkit.config_modules import DatasetConfig from toolkit.dataloader_mixins import CaptionMixin +BUCKET_STEPS = 64 + +def get_bucket_sizes_for_resolution(resolution: int) -> List[int]: + # make sure resolution is divisible by 8 + if resolution % 8 != 0: + resolution = resolution - (resolution % 8) + class ImageDataset(Dataset, CaptionMixin): def __init__(self, config): @@ -357,6 +364,7 @@ class AiToolkitDataset(Dataset, CaptionMixin): def get_dataloader_from_datasets(dataset_options, batch_size=1): + # TODO do bucketing if dataset_options is None or len(dataset_options) == 0: return None diff --git a/toolkit/embedding.py b/toolkit/embedding.py index 40e01250..3eb4483a 100644 --- a/toolkit/embedding.py +++ b/toolkit/embedding.py @@ -25,7 +25,9 @@ class Embedding: ): self.name = embed_config.trigger self.sd = sd + self.trigger = embed_config.trigger self.embed_config = embed_config + self.step = 0 # setup our embedding # Add the placeholder token in tokenizer placeholder_tokens = [self.embed_config.trigger] @@ -64,10 +66,7 @@ class Embedding: for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids): token_embeds[token_id] = token_embeds[initializer_token_id].clone() - # this doesnt seem to be used again - self.token_embeds = token_embeds - - # replace "[name] with this. This triggers it in the text encoder + # replace "[name] with this. on training. This is automatically generated in pipeline on inference self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids)) # returns the string to have in the prompt to trigger the embedding @@ -86,7 +85,7 @@ class Embedding: token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data # stack the tokens along batch axis adding that axis new_vector = torch.stack( - [token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids], + [token_embeds[token_id] for token_id in self.placeholder_token_ids], dim=0 ) return new_vector @@ -100,6 +99,39 @@ class Embedding: token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone() x = 1 + # diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc + # however, on training we don't use that pipeline, so we have to do it ourselves + def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True): + output_prompt = prompt + default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens] + + replace_with = self.embedding_tokens if expand_token else self.trigger + if to_replace_list is None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + if num_instances > 1: + print( + f"Warning: {self.name} token appears {num_instances} times in 
prompt {output_prompt}. This may cause issues.") + + return output_prompt + def save(self, filename): # todo check to see how to get the vector out of the embedding @@ -107,7 +139,7 @@ class Embedding: "string_to_token": {"*": 265}, "string_to_param": {"*": self.vec}, "name": self.name, - "step": 0, + "step": self.step, # todo get these "sd_checkpoint": None, "sd_checkpoint_name": None, @@ -182,4 +214,7 @@ class Embedding: raise Exception( f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + if 'step' in data: + self.step = int(data['step']) + self.vec = emb.detach().to(device, dtype=torch.float32) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 4e7c1141..e4667ee6 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -435,7 +435,7 @@ class StableDiffusion: text_embeddings = train_tools.concat_prompt_embeddings( unconditional_embeddings, # negative embedding conditional_embeddings, # positive embedding - latents.shape[0], # batch size + 1, # batch size ) elif text_embeddings is None and conditional_embeddings is not None: # not doing cfg @@ -506,6 +506,17 @@ class StableDiffusion: latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep) + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor): + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat([timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2) + else: + raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + # predict the noise residual noise_pred = self.unet( latent_model_input, From f200cf36c5d9a33b5e412fa87c8e031363492788 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 23 Aug 2023 13:30:29 -0600 Subject: [PATCH 3/3] Added train example to ti --- .../config/train.example.yaml | 123 ++++++++---------- 1 file changed, 54 insertions(+), 69 deletions(-) diff --git a/extensions_built_in/textual_inversion_trainer/config/train.example.yaml b/extensions_built_in/textual_inversion_trainer/config/train.example.yaml index 8b0f4734..90f3c0b1 100644 --- a/extensions_built_in/textual_inversion_trainer/config/train.example.yaml +++ b/extensions_built_in/textual_inversion_trainer/config/train.example.yaml @@ -1,70 +1,72 @@ --- job: extension config: - name: example_name + name: test_v1 process: - - type: 'image_reference_slider_trainer' - training_folder: "/mnt/Train/out/LoRA" + - type: 'textual_inversion_trainer' + training_folder: "out/TI" device: cuda:0 # for tensorboard logging - log_dir: "/home/jaret/Dev/.tensorboard" - network: - type: "lora" - linear: 8 - linear_alpha: 8 + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_type: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 train: noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" - steps: 5000 - lr: 1e-4 - train_unet: true + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + 
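The `predict_noise` change earlier in this commit aligns the timestep batch with the (possibly CFG-duplicated) latent batch. A standalone sketch of that alignment logic, using made-up tensor shapes:

```python
import torch

# With classifier-free guidance the latents are duplicated (uncond + cond), so a
# timestep tensor of batch size N must be repeated to match 2N latent samples.
latents = torch.randn(4, 4, 64, 64)       # already duplicated for CFG (2 * batch of 2)
timesteps = torch.randint(0, 1000, (2,))  # one timestep per original sample

if timesteps.shape[0] != latents.shape[0]:
    if timesteps.shape[0] == 1:
        timesteps = torch.cat([timesteps] * latents.shape[0])
    elif timesteps.shape[0] * 2 == latents.shape[0]:
        timesteps = torch.cat([timesteps] * 2)
    else:
        raise ValueError("timestep batch must match or be half the latent batch")

assert timesteps.shape[0] == latents.shape[0]
```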
train_unet: false gradient_checkpointing: true - train_text_encoder: true + train_text_encoder: false optimizer: "adamw" +# optimizer: "prodigy" optimizer_params: weight_decay: 1e-2 lr_scheduler: "constant" max_denoising_steps: 1000 - batch_size: 1 + batch_size: 4 dtype: bf16 xformers: true - skip_first_sample: true - noise_offset: 0.0 + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this model: - name_or_path: "/path/to/model.safetensors" + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" is_v2: false # for v2 models is_xl: false # for SDXL models is_v_pred: false # for v-prediction models (most v2 models) - save: - dtype: float16 # precision to save - save_every: 1000 # save every this many steps - max_step_saves_to_keep: 2 # only affects step counts sample: sampler: "ddpm" # must match train.noise_scheduler sample_every: 100 # sample every this many steps width: 512 height: 512 prompts: - - "photo of a woman with red hair taking a selfie --m -3" - - "photo of a woman with red hair taking a selfie --m -1" - - "photo of a woman with red hair taking a selfie --m 1" - - "photo of a woman with red hair taking a selfie --m 3" - - "close up photo of a man smiling at the camera, in a tank top --m -3" - - "close up photo of a man smiling at the camera, in a tank top--m -1" - - "close up photo of a man smiling at the camera, in a tank top --m 1" - - "close up photo of a man smiling at the camera, in a tank top --m 3" - - "photo of a blonde woman smiling, barista --m -3" - - "photo of a blonde woman smiling, barista --m -1" - - "photo of a blonde woman smiling, barista --m 1" - - "photo of a blonde woman smiling, barista --m 3" - - "photo of a Christina Hendricks --m -1" - - "photo of a Christina Hendricks --m -1" - - "photo of a Christina Hendricks --m 1" - - "photo of a Christina Hendricks --m 3" - - "photo of a Christina Ricci --m -3" - - "photo of a Christina Ricci --m -1" - - "photo of a Christina Ricci --m 1" - - "photo of a Christina Ricci --m 3" - neg: "cartoon, fake, drawing, illustration, cgi, animated, anime" + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + - "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" seed: 42 walk_seed: false guidance_scale: 7 @@ -76,32 +78,15 @@ config: use_wandb: false # not supported yet verbose: false - slider: - datasets: - - pair_folder: "/path/to/folder/side/by/side/images" - network_weight: 2.0 - target_class: "" # only used as default if caption txt are not present - size: 512 - - pair_folder: "/path/to/folder/side/by/side/images" - network_weight: 4.0 - target_class: "" # only used as default if caption txt are not present - size: 512 - - -# you can put any information you want here, and it will be saved in the model -# the below is an example. I recommend doing trigger words at a minimum -# in the metadata. The software will include this plus some other information +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. 
The software will include this +# plus some other information for you automatically meta: - name: "[name]" # [name] gets replaced with the name above - description: A short description of your model - trigger_words: - - put - - trigger - - words - - here - version: '0.1' - creator: - name: Your Name - email: your@email.com - website: https://yourwebsite.com - any: All meta data above is arbitrary, it can be whatever you want. \ No newline at end of file + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website
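Taken together, the embedding setup in `toolkit/embedding.py` and the `index_no_updates` restore in the trainer implement the usual textual-inversion recipe: append placeholder tokens, initialize them from the init words, train only those rows, and reset every other row after each optimizer step. Below is a self-contained sketch of that restore step with made-up vocabulary sizes and token ids (no toolkit dependencies).

```python
import torch
import torch.nn as nn

# Hypothetical sizes: pretend 4 placeholder tokens were appended to a CLIP-sized vocab.
vocab_size, dim = 49408 + 4, 768
placeholder_token_ids = [49408, 49409, 49410, 49411]

embeddings = nn.Embedding(vocab_size, dim)
orig_embeds = embeddings.weight.data.clone()  # reference copy, as in hook_before_train_loop

optimizer = torch.optim.AdamW(embeddings.parameters(), lr=5e-5)

# One toy optimization step (the real trainer uses the diffusion MSE loss here).
loss = embeddings(torch.randint(0, vocab_size, (8,))).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Restore every row except the placeholder rows, as in hook_train_loop.
index_no_updates = torch.ones(vocab_size, dtype=torch.bool)
index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
with torch.no_grad():
    embeddings.weight[index_no_updates] = orig_embeds[index_no_updates]
```

Without the restore, AdamW's decoupled weight decay alone would slowly drift every row of the embedding table, not just the trained placeholder rows.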