WIP creating textual inversion training script

This commit is contained in:
Jaret Burkett
2023-08-22 21:02:38 -06:00
parent 36ba08d3fa
commit 2e6c55c720
9 changed files with 746 additions and 6 deletions


@@ -0,0 +1,179 @@
import copy
import gc
import os
import random
from collections import OrderedDict
from contextlib import nullcontext
from typing import Optional, Union, List

import torch
from torch.utils.data import ConcatDataset, DataLoader

from jobs.process import BaseSDTrainProcess
from toolkit import train_tools
from toolkit.basic import value_map
from toolkit.config_modules import ReferenceDatasetConfig
from toolkit.data_loader import PairedImageDataset, ImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset


def flush():
    # release cached GPU memory and run garbage collection
    torch.cuda.empty_cache()
    gc.collect()
class TextualInversionTrainer(BaseSDTrainProcess):
    sd: StableDiffusion
    data_loader: Optional[DataLoader] = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)

    def before_model_load(self):
        pass

    def hook_before_train_loop(self):
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)
        # keep a copy of the original embeddings so that, after each step, every
        # row except the newly added placeholder token(s) can be restored
        self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
        # set the text encoder to train mode. Not sure if this is necessary,
        # but the diffusers example does it
        self.sd.text_encoder.train()
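    # For context: textual inversion works by adding new placeholder tokens to
    # the tokenizer, resizing the token embedding matrix, and training only the
    # new rows. A minimal sketch of that setup with the transformers APIs (the
    # token and init values are illustrative assumptions, and this step is
    # presumably handled elsewhere in the toolkit):
    #
    #   tokenizer.add_tokens(["<my-token>"])
    #   text_encoder.resize_token_embeddings(len(tokenizer))
    #   placeholder_id = tokenizer.convert_tokens_to_ids("<my-token>")
    #   init_id = tokenizer.convert_tokens_to_ids("person")
    #   embeds = text_encoder.get_input_embeddings().weight.data
    #   embeds[placeholder_id] = embeds[init_id].clone()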
    def hook_train_loop(self, batch):
        with torch.no_grad():
            imgs, prompts = batch

            # very loosely based on the diffusers textual inversion example. very loosely.
            # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
            conditioned_prompts = []
            for prompt in prompts:
                # if the trigger word is in the prompt, replace it with the embedding string
                if self.embed_config.trigger in prompt:
                    prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
                # if the process name is in the prompt, replace it with the embedding string
                if self.name in prompt:
                    prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
                # if [name] is in the prompt, replace it with the embedding string
                if "[name]" in prompt:
                    prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
                # otherwise prepend the embedding string to the prompt
                if self.embedding.get_embedding_string() not in prompt:
                    prompt = self.embedding.get_embedding_string() + " " + prompt
                conditioned_prompts.append(prompt)
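            # e.g. with trigger "xyz" and an embedding string like "<s0>" (both
            # values here are illustrative assumptions), "photo of xyz smiling"
            # becomes "photo of <s0> smiling", and a prompt with no match
            # becomes "<s0> photo of a dog"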
            # unused tokenizer path, left commented out in this WIP
            # embedding_ids_list = [self.sd.tokenizer(
            #     text,
            #     padding="max_length",
            #     truncation=True,
            #     max_length=self.sd.tokenizer.model_max_length,
            #     return_tensors="pt",
            # ).input_ids[0] for text in conditioned_prompts]
            # hidden_states = []
            # for embedding_ids, img in zip(embedding_ids_list, imgs):
            #     hidden_state = {
            #         "input_ids": embedding_ids,
            #         "pixel_values": img
            #     }
            #     hidden_states.append(hidden_state)
            dtype = get_torch_dtype(self.train_config.dtype)
            imgs = imgs.to(self.device_torch, dtype=dtype)
            latents = self.sd.encode_images(imgs)

            noise_scheduler = self.sd.noise_scheduler
            optimizer = self.optimizer
            lr_scheduler = self.lr_scheduler

            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
            timesteps = timesteps.long()

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=imgs.shape[2],
                pixel_width=imgs.shape[3],
                batch_size=self.train_config.batch_size,
                noise_offset=self.train_config.noise_offset
            ).to(self.device_torch, dtype=dtype)

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # these are fixed inputs, not trainable parameters
            noisy_latents.requires_grad = False
            noise.requires_grad = False

        flush()

        # the text encoding below must run with gradients enabled so the loss
        # can reach the token embeddings, so the no_grad block ends here
        self.optimizer.zero_grad()
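        # for reference, scheduler.add_noise above computes the standard DDPM
        # forward process; a sketch of the usual diffusers semantics, assuming
        # an alphas_cumprod table indexed by timestep:
        #
        #   a = noise_scheduler.alphas_cumprod[timesteps]
        #   noisy_latents = a.sqrt() * latents + (1 - a).sqrt() * noise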
        # text encoding: embed each conditioned prompt, with gradients flowing
        # into the placeholder token embedding rows
        embedding_list = []
        for prompt in conditioned_prompts:
            embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
            embedding_list.append(embedding)
        conditional_embeds = concat_prompt_embeds(embedding_list)

        noise_pred = self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            timestep=timesteps,
            guidance_scale=1.0,  # 1.0 = no classifier-free guidance
        )
        noise = noise.to(self.device_torch, dtype=dtype)

        if self.sd.prediction_type == 'v_prediction':
            # v-parameterization training: the velocity target is computed from
            # the clean latents, not the noisy ones
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            target = noise

        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
        loss = loss.mean([1, 2, 3])

        if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
            # apply min-SNR-gamma loss weighting
            loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)

        loss = loss.mean()
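        # for reference, min-SNR-gamma weighting (Hang et al., 2023) rescales
        # the per-sample loss by min(SNR_t, gamma) / SNR_t before the mean; a
        # sketch of what apply_snr_weight presumably computes, assuming the
        # scheduler exposes alphas_cumprod:
        #
        #   a = noise_scheduler.alphas_cumprod[timesteps]
        #   snr = a / (1 - a)
        #   loss = loss * torch.clamp(snr, max=gamma) / snr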
        # backpropagate, then flush to free VRAM before the optimizer step
        loss.backward()
        flush()

        # apply gradients
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        # make sure we don't update any embedding weights besides the newly added
        # token: restore every other row of the embedding matrix from the copy
        # saved in hook_before_train_loop
        index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
        index_no_updates[
            min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
        with torch.no_grad():
            self.sd.text_encoder.get_input_embeddings().weight[
                index_no_updates
            ] = self.orig_embeds_params[index_no_updates]

        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )
        return loss_dict
        # end hook_train_loop
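
The restore step above is the usual way to train only the new token rows without masking gradients: the optimizer updates the whole embedding matrix, then every row except the placeholder ids is copied back from the saved original weights. A minimal standalone sketch of the same trick with a plain torch embedding (the sizes and ids are illustrative, not this toolkit's API):

import torch

emb = torch.nn.Embedding(10, 4)   # 10 tokens, 4-dim embeddings
orig = emb.weight.data.clone()    # snapshot taken before training
placeholder_ids = [9]             # the newly added token row

# ... after optimizer.step() ...
keep = torch.ones(emb.num_embeddings, dtype=torch.bool)
keep[placeholder_ids] = False
with torch.no_grad():
    emb.weight[keep] = orig[keep]  # only row 9 keeps its update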


@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class TextualInversionTrainerExtension(Extension):
    # uid must be unique; it is how the extension is identified
    uid = "textual_inversion_trainer"

    # name is the name of the extension for printing
    name = "Textual Inversion Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed, and return it
        from .TextualInversionTrainer import TextualInversionTrainer
        return TextualInversionTrainer


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    TextualInversionTrainerExtension
]
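
Presumably the toolkit collects AI_TOOLKIT_EXTENSIONS from each extension module and matches a job config's process type against the uid; a hypothetical sketch of that lookup (not the toolkit's actual loader code):

def find_process_class(extensions, process_type: str):
    # match a job config "type" string against each registered extension's uid
    for ext in extensions:
        if ext.uid == process_type:
            return ext.get_process()
    raise ValueError(f"no extension with uid '{process_type}'")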


@@ -0,0 +1,107 @@
---
job: extension
config:
  name: example_name
  process:
    - type: 'image_reference_slider_trainer'
      training_folder: "/mnt/Train/out/LoRA"
      device: cuda:0
      # for tensorboard logging
      log_dir: "/home/jaret/Dev/.tensorboard"
      network:
        type: "lora"
        linear: 8
        linear_alpha: 8
      train:
        noise_scheduler: "ddpm" # or "ddim", "lms", "euler_a"
        steps: 5000
        lr: 1e-4
        train_unet: true
        gradient_checkpointing: true
        train_text_encoder: true
        optimizer: "adamw"
        optimizer_params:
          weight_decay: 1e-2
        lr_scheduler: "constant"
        max_denoising_steps: 1000
        batch_size: 1
        dtype: bf16
        xformers: true
        skip_first_sample: true
        noise_offset: 0.0
      model:
        name_or_path: "/path/to/model.safetensors"
        is_v2: false # for v2 models
        is_xl: false # for SDXL models
        is_v_pred: false # for v-prediction models (most v2 models)
      save:
        dtype: float16 # precision to save in
        save_every: 1000 # save every this many steps
        max_step_saves_to_keep: 2 # keep only the newest 2 intermediate step saves
      sample:
        sampler: "ddpm" # must match train.noise_scheduler
        sample_every: 100 # sample every this many steps
        width: 512
        height: 512
        prompts:
          - "photo of a woman with red hair taking a selfie --m -3"
          - "photo of a woman with red hair taking a selfie --m -1"
          - "photo of a woman with red hair taking a selfie --m 1"
          - "photo of a woman with red hair taking a selfie --m 3"
          - "close up photo of a man smiling at the camera, in a tank top --m -3"
          - "close up photo of a man smiling at the camera, in a tank top --m -1"
          - "close up photo of a man smiling at the camera, in a tank top --m 1"
          - "close up photo of a man smiling at the camera, in a tank top --m 3"
          - "photo of a blonde woman smiling, barista --m -3"
          - "photo of a blonde woman smiling, barista --m -1"
          - "photo of a blonde woman smiling, barista --m 1"
          - "photo of a blonde woman smiling, barista --m 3"
          - "photo of Christina Hendricks --m -3"
          - "photo of Christina Hendricks --m -1"
          - "photo of Christina Hendricks --m 1"
          - "photo of Christina Hendricks --m 3"
          - "photo of Christina Ricci --m -3"
          - "photo of Christina Ricci --m -1"
          - "photo of Christina Ricci --m 1"
          - "photo of Christina Ricci --m 3"
        neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
        seed: 42
        walk_seed: false
        guidance_scale: 7
        sample_steps: 20
        network_multiplier: 1.0
      logging:
        log_every: 10 # log every this many steps
        use_wandb: false # not supported yet
        verbose: false
      slider:
        datasets:
          - pair_folder: "/path/to/folder/side/by/side/images"
            network_weight: 2.0
            target_class: "" # only used as a default if caption .txt files are not present
            size: 512
          - pair_folder: "/path/to/folder/side/by/side/images"
            network_weight: 4.0
            target_class: "" # only used as a default if caption .txt files are not present
            size: 512
# you can put any information you want here, and it will be saved in the model.
# the below is an example. I recommend including trigger words at a minimum
# in the metadata. The software will include this plus some other information.
meta:
  name: "[name]" # [name] gets replaced with the name above
  description: A short description of your model
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    email: your@email.com
    website: https://yourwebsite.com
  any: All metadata above is arbitrary; it can be whatever you want.
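
The `--m N` suffix on each sample prompt appears to set a per-prompt network multiplier, which is how the slider trainer samples the same prompt at several strengths. A hypothetical sketch of how such a flag could be split off a prompt string (not the toolkit's actual parser):

def split_prompt_flag(prompt: str, flag: str = "--m"):
    # "photo of a dog --m -3" -> ("photo of a dog", -3.0)
    if flag in prompt:
        text, _, value = prompt.partition(flag)
        return text.strip(), float(value.strip())
    return prompt.strip(), 1.0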