WIP creating textual inversion training script

Jaret Burkett
2023-08-22 21:02:38 -06:00
parent 36ba08d3fa
commit 2e6c55c720
9 changed files with 746 additions and 6 deletions


@@ -0,0 +1,179 @@
import copy
import random
from collections import OrderedDict
import os
from contextlib import nullcontext
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader
from toolkit.config_modules import ReferenceDatasetConfig
from toolkit.data_loader import PairedImageDataset, ImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
from toolkit.basic import value_map
def flush():
torch.cuda.empty_cache()
gc.collect()
class TextualInversionTrainer(BaseSDTrainProcess):
sd: StableDiffusion
data_loader: DataLoader = None
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
pass
def before_model_load(self):
pass
def hook_before_train_loop(self):
self.sd.vae.eval()
self.sd.vae.to(self.device_torch)
# keep original embeddings as reference
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
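# these rows get copied back over every token except the new placeholder tokens after each optimizer step (see hook_train_loop)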
# set the text encoder to train mode. Not sure if this is necessary, but the diffusers example does it
self.sd.text_encoder.train()
pass
def hook_train_loop(self, batch):
with torch.no_grad():
imgs, prompts = batch
# very loosely based on this. very loosely
# ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
conditioned_prompts = []
for prompt in prompts:
# replace triggers/placeholders in the prompt with the embedding token string
if self.embed_config.trigger in prompt:
# if the trigger word is in the prompt, swap in the embedding token string
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
if self.name in prompt:
# if the process name is in the prompt, swap in the embedding token string
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
if "[name]" in prompt:
# if "[name]" is in the prompt, swap in the embedding token string
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
if self.embedding.get_embedding_string() not in prompt:
# add it to the beginning of the prompt
prompt = self.embedding.get_embedding_string() + " " + prompt
conditioned_prompts.append(prompt)
# # get embedding ids
# embedding_ids_list = [self.sd.tokenizer(
# text,
# padding="max_length",
# truncation=True,
# max_length=self.sd.tokenizer.model_max_length,
# return_tensors="pt",
# ).input_ids[0] for text in conditioned_prompts]
# hidden_states = []
# for embedding_ids, img in zip(embedding_ids_list, imgs):
# hidden_state = {
# "input_ids": embedding_ids,
# "pixel_values": img
# }
# hidden_states.append(hidden_state)
dtype = get_torch_dtype(self.train_config.dtype)
imgs = imgs.to(self.device_torch, dtype=dtype)
latents = self.sd.encode_images(imgs)
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
timesteps = timesteps.long()
# get noise
noise = self.sd.get_latent_noise(
pixel_height=imgs.shape[2],
pixel_width=imgs.shape[3],
batch_size=self.train_config.batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
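# forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise at the sampled timestep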
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# remove grads for these
noisy_latents.requires_grad = False
noise.requires_grad = False
flush()
self.optimizer.zero_grad()
# text encoding
embedding_list = []
# embed the prompts
for prompt in conditioned_prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
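# concatenate the per-prompt embeddings into a single batch tensor for the UNet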
conditional_embeds = concat_prompt_embeds(embedding_list)
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
)
noise = noise.to(self.device_torch, dtype=dtype)
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training: the target is the velocity
# v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0, computed from the clean latents
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# min-SNR-gamma loss weighting (arXiv:2303.09556): down-weights low-noise timesteps by roughly min(SNR_t, gamma) / SNR_t
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# backpropagate now so activation memory can be freed before the optimizer step
loss.backward()
flush()
# apply gradients
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
# Let's make sure we don't update any embedding weights besides the newly added token
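# index_no_updates is True for every vocab row that should be restored and False for the trainable placeholder rows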
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
index_no_updates[
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
with torch.no_grad():
self.sd.text_encoder.get_input_embeddings().weight[
index_no_updates
] = self.orig_embeds_params[index_no_updates]
loss_dict = OrderedDict(
{'loss': loss.item()}
)
return loss_dict
# end hook_train_loop


@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# We make a subclass of Extension
class TextualInversionTrainerExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "textual_inversion_trainer"
# name is the name of the extension for printing
name = "Textual Inversion Trainer"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .TextualInversionTrainer import TextualInversionTrainer
return TextualInversionTrainer
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
TextualInversionTrainerExtension
]
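# a process entry in a job config would select this trainer by its uid (sketch):
#   process:
#     - type: 'textual_inversion_trainer'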


@@ -0,0 +1,107 @@
---
job: extension
config:
name: example_name
process:
- type: 'image_reference_slider_trainer'
training_folder: "/mnt/Train/out/LoRA"
device: cuda:0
# for tensorboard logging
log_dir: "/home/jaret/Dev/.tensorboard"
network:
type: "lora"
linear: 8
linear_alpha: 8
train:
noise_scheduler: "ddpm" # or "lms", "euler_a"
steps: 5000
lr: 1e-4
train_unet: true
gradient_checkpointing: true
train_text_encoder: true
optimizer: "adamw"
optimizer_params:
weight_decay: 1e-2
lr_scheduler: "constant"
max_denoising_steps: 1000
batch_size: 1
dtype: bf16
xformers: true
skip_first_sample: true
noise_offset: 0.0
model:
name_or_path: "/path/to/model.safetensors"
is_v2: false # for v2 models
is_xl: false # for SDXL models
is_v_pred: false # for v-prediction models (most v2 models)
save:
dtype: float16 # precision to save
save_every: 1000 # save every this many steps
max_step_saves_to_keep: 2 # how many step saves to keep; only step saves are affected
sample:
sampler: "ddpm" # must match train.noise_scheduler
sample_every: 100 # sample every this many steps
width: 512
height: 512
prompts:
- "photo of a woman with red hair taking a selfie --m -3"
- "photo of a woman with red hair taking a selfie --m -1"
- "photo of a woman with red hair taking a selfie --m 1"
- "photo of a woman with red hair taking a selfie --m 3"
- "close up photo of a man smiling at the camera, in a tank top --m -3"
- "close up photo of a man smiling at the camera, in a tank top--m -1"
- "close up photo of a man smiling at the camera, in a tank top --m 1"
- "close up photo of a man smiling at the camera, in a tank top --m 3"
- "photo of a blonde woman smiling, barista --m -3"
- "photo of a blonde woman smiling, barista --m -1"
- "photo of a blonde woman smiling, barista --m 1"
- "photo of a blonde woman smiling, barista --m 3"
- "photo of a Christina Hendricks --m -1"
- "photo of a Christina Hendricks --m -1"
- "photo of a Christina Hendricks --m 1"
- "photo of a Christina Hendricks --m 3"
- "photo of a Christina Ricci --m -3"
- "photo of a Christina Ricci --m -1"
- "photo of a Christina Ricci --m 1"
- "photo of a Christina Ricci --m 3"
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
seed: 42
walk_seed: false
guidance_scale: 7
sample_steps: 20
network_multiplier: 1.0
logging:
log_every: 10 # log every this many steps
use_wandb: false # not supported yet
verbose: false
slider:
datasets:
- pair_folder: "/path/to/folder/side/by/side/images"
network_weight: 2.0
target_class: "" # only used as default if caption txt are not present
size: 512
- pair_folder: "/path/to/folder/side/by/side/images"
network_weight: 4.0
target_class: "" # only used as default if caption txt are not present
size: 512
# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend including trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
name: "[name]" # [name] gets replaced with the name above
description: A short description of your model
trigger_words:
- put
- trigger
- words
- here
version: '0.1'
creator:
name: Your Name
email: your@email.com
website: https://yourwebsite.com
any: All metadata above is arbitrary; it can be whatever you want.


@@ -5,6 +5,8 @@ from typing import Union
from torch.utils.data import DataLoader
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -20,7 +22,7 @@ import torch
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig
def flush():
@@ -30,6 +32,7 @@ def flush():
class BaseSDTrainProcess(BaseTrainProcess):
sd: StableDiffusion
embedding: Union[Embedding, None] = None
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
super().__init__(process_id, job, config)
@@ -59,6 +62,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
raw_datasets = self.get_conf('datasets', None)
self.datasets = None
if raw_datasets is not None and len(raw_datasets) > 0:
self.datasets = [DatasetConfig(**d) for d in raw_datasets]
self.embed_config = None
embedding_raw = self.get_conf('embedding', None)
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
@@ -68,6 +81,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network = None
self.embedding = None
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
@@ -89,8 +103,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
output_path = os.path.join(sample_folder, filename)
prompt = sample_config.prompts[i]
# add embedding if there is one
if self.embedding is not None:
# replace triggers/placeholders in the prompt with the embedding token string
if self.embed_config.trigger in prompt:
# if the trigger word is in the prompt, swap in the embedding token string
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
if self.name in prompt:
# if the process name is in the prompt, swap in the embedding token string
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
if "[name]" in prompt:
# if "[name]" is in the prompt, swap in the embedding token string
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
if self.embedding.get_embedding_string() not in prompt:
# add it to the beginning of the prompt
prompt = self.embedding.get_embedding_string() + " " + prompt
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
width=sample_config.width,
height=sample_config.height,
negative_prompt=sample_config.neg,
@@ -175,6 +207,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
metadata=save_meta
)
self.network.multiplier = prev_multiplier
elif self.embedding is not None:
self.embedding.save(file_path)
else:
self.sd.save(
file_path,
@@ -197,6 +231,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
def hook_before_train_loop(self):
pass
def before_dataset_load(self):
pass
def hook_train_loop(self, batch=None):
# return loss
return 0.0
@@ -208,6 +245,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
pattern = f"{self.job.name}*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
# try pt
pattern = f"{self.job.name}*.pt"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
return latest_file
@@ -230,11 +272,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
def run(self):
# run base process run
BaseTrainProcess.run(self)
### HOOK ###
self.before_dataset_load()
# load datasets if passed in the root process
if self.datasets is not None:
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
### HOOK ###
self.hook_before_model_load()
# run base sd process run
self.sd.load_model()
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
dtype = get_torch_dtype(self.train_config.dtype)
# model is loaded from BaseSDProcess
@@ -303,7 +355,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print(f"Loading from {latest_save_path}")
self.load_weights(latest_save_path)
self.network.multiplier = 1.0
elif self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
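# Embedding adds the placeholder tokens to the tokenizer, resizes the text encoder's input embeddings, and initializes the new rows from init_words (see toolkit/embedding.py)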
latest_save_path = self.get_latest_save_path()
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# set trainable params
params = self.embedding.get_trainable_params()
else:
params = []


@@ -1,6 +1,6 @@
import os
import time
from typing import List, Optional
from typing import List, Optional, Literal
import random
@@ -50,6 +50,13 @@ class NetworkConfig:
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
class EmbeddingConfig:
def __init__(self, **kwargs):
self.trigger = kwargs.get('trigger', 'custom_embedding')
self.tokens = kwargs.get('tokens', 4)
self.init_words = kwargs.get('init_phrase', '*')
self.save_format = kwargs.get('save_format', 'safetensors')
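# example config section this maps to (values are placeholders):
#   embedding:
#     trigger: "my_concept"
#     tokens: 4
#     init_phrase: "a photo of a person"
#     save_format: "safetensors"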
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
@@ -146,6 +153,20 @@ class SliderConfig:
self.targets.append(target)
class DatasetConfig:
caption_type: Literal["txt", "caption"] = 'txt'
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'image')  # only 'image' is handled by get_dataloader_from_datasets for now
self.folder_path: str = kwargs.get('folder_path', None)
self.default_caption: str = kwargs.get('default_caption', None)
self.caption_type: str = kwargs.get('caption_type', None)
self.random_scale: bool = kwargs.get('random_scale', False)
self.random_crop: bool = kwargs.get('random_crop', False)
self.resolution: int = kwargs.get('resolution', 512)
self.scale: float = kwargs.get('scale', 1.0)
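# example datasets entry this maps to (paths/values are placeholders):
#   datasets:
#     - folder_path: "/path/to/images"
#       caption_type: "txt"
#       default_caption: "a photo"
#       resolution: 512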
class GenerateImageConfig:
def __init__(
self,


@@ -1,23 +1,33 @@
import os
import random
from typing import List
import cv2
import numpy as np
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A
from toolkit.config_modules import DatasetConfig
from toolkit.dataloader_mixins import CaptionMixin
class ImageDataset(Dataset, CaptionMixin):
def __init__(self, config):
self.config = config
self.name = self.get_config('name', 'dataset')
self.path = self.get_config('path', required=True)
self.scale = self.get_config('scale', 1)
self.random_scale = self.get_config('random_scale', False)
self.include_prompt = self.get_config('include_prompt', False)
self.default_prompt = self.get_config('default_prompt', '')
if self.include_prompt:
self.caption_type = self.get_config('caption_type', 'txt')
else:
self.caption_type = None
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)
@@ -81,7 +91,11 @@ class ImageDataset(Dataset):
img = self.transform(img)
if self.include_prompt:
prompt = self.get_caption_item(index)
return img, prompt
else:
return img
class Augments:
@@ -268,3 +282,101 @@ class PairedImageDataset(Dataset):
img = self.transform(img)
return img, prompt, (self.neg_weight, self.pos_weight)
class AiToolkitDataset(Dataset, CaptionMixin):
def __init__(self, dataset_config: 'DatasetConfig'):
self.dataset_config = dataset_config
self.folder_path = dataset_config.folder_path
self.caption_type = dataset_config.caption_type
self.default_caption = dataset_config.default_caption
self.random_scale = dataset_config.random_scale
self.scale = dataset_config.scale
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
self.resolution = dataset_config.resolution
# get the file list
self.file_list = [
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
# this might take a while
print(f" - Preprocessing image dimensions")
new_file_list = []
bad_count = 0
for file in tqdm(self.file_list):
img = Image.open(file)
if int(min(img.size) * self.scale) >= self.resolution:
new_file_list.append(file)
else:
bad_count += 1
# keep only the images that are large enough after scaling
self.file_list = new_file_list
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
img_path = self.file_list[index]
img = exif_transpose(Image.open(img_path)).convert('RGB')
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
min_img_size = min(img.size)
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)
if self.caption_type is not None:
prompt = self.get_caption_item(index)
return img, prompt
else:
return img
def get_dataloader_from_datasets(dataset_options, batch_size=1):
if dataset_options is None or len(dataset_options) == 0:
return None
datasets = []
for dataset_option in dataset_options:
if isinstance(dataset_option, DatasetConfig):
config = dataset_option
else:
config = DatasetConfig(**dataset_option)
if config.type == 'image':
dataset = AiToolkitDataset(config)
datasets.append(dataset)
else:
raise ValueError(f"invalid dataset type: {config.type}")
concatenated_dataset = ConcatDataset(datasets)
data_loader = DataLoader(
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
)
return data_loader
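# minimal usage sketch (paths are placeholders):
#   loader = get_dataloader_from_datasets([{'folder_path': '/path/to/images', 'caption_type': 'txt'}], batch_size=1)
#   for imgs, prompts in loader:
#       ...  # each item is (image_tensor, caption) when caption_type is set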


@@ -0,0 +1,43 @@
import os
class CaptionMixin:
def get_caption_item(self, index):
if not hasattr(self, 'caption_type'):
raise Exception('caption_type not found on class instance')
if not hasattr(self, 'file_list'):
raise Exception('file_list not found on class instance')
img_path_or_tuple = self.file_list[index]
if isinstance(img_path_or_tuple, tuple):
# check if either has a prompt file
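# paired (side-by-side) datasets share one caption; if the first image has no .txt file, fall back to the second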
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
prompt_path = path_no_ext + '.txt'
if not os.path.exists(prompt_path):
path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
prompt_path = path_no_ext + '.txt'
else:
img_path = img_path_or_tuple
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
# remove any newlines
prompt = prompt.replace('\n', ', ')
# remove new lines for all operating systems
prompt = prompt.replace('\r', ', ')
prompt_split = prompt.split(',')
# remove empty strings
prompt_split = [p.strip() for p in prompt_split if p.strip()]
# join back together
prompt = ', '.join(prompt_split)
else:
prompt = ''
# fall back to a class-level default caption if one is set
if hasattr(self, 'default_prompt') and self.default_prompt is not None:
prompt = self.default_prompt
if hasattr(self, 'default_caption') and self.default_caption is not None:
prompt = self.default_caption
return prompt

toolkit/embedding.py

@@ -0,0 +1,185 @@
import json
import os
from collections import OrderedDict
import safetensors
import torch
from typing import TYPE_CHECKING
from safetensors.torch import save_file
from toolkit.metadata import get_meta_for_safetensors
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import EmbeddingConfig
# this is a frankenstein mix of automatic1111 and my own code
class Embedding:
def __init__(
self,
sd: 'StableDiffusion',
embed_config: 'EmbeddingConfig'
):
self.name = embed_config.trigger
self.sd = sd
self.embed_config = embed_config
# setup our embedding
# Add the placeholder token in tokenizer
placeholder_tokens = [self.embed_config.trigger]
# add dummy tokens for multi-vector
additional_tokens = []
for i in range(1, self.embed_config.tokens):
additional_tokens.append(f"{self.embed_config.trigger}_{i}")
placeholder_tokens += additional_tokens
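# e.g. a trigger of "concept" with tokens=4 registers: concept, concept_1, concept_2, concept_3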
num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.embed_config.tokens:
raise ValueError(
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Convert the initializer_token, placeholder_token to ids
init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
# if there are more init token ids than embedding tokens, truncate; if fewer, pad with "*"
if len(init_token_ids) > self.embed_config.tokens:
init_token_ids = init_token_ids[:self.embed_config.tokens]
elif len(init_token_ids) < self.embed_config.tokens:
pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
# todo SDXL has 2 text encoders, need to do both for all of this
self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
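# each new placeholder row starts as a copy of its init word's embedding, which gives training a reasonable starting point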
# this doesn't seem to be used again
self.token_embeds = token_embeds
# replace "[name] with this. This triggers it in the text encoder
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
# returns the string to have in the prompt to trigger the embedding
def get_embedding_string(self):
return self.embedding_tokens
def get_trainable_params(self):
# todo only get this one as we could have more than one
return self.sd.text_encoder.get_input_embeddings().parameters()
# make setter and getter for vec
@property
def vec(self):
# should we get params instead
# create vector from token embeds
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
# stack the tokens along batch axis adding that axis
new_vector = torch.stack(
[token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids],
dim=0
)
return new_vector
@vec.setter
def vec(self, new_vector):
# shape is (1, 768) for SD 1.5 for 1 token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
for i in range(new_vector.shape[0]):
# apply the weights to the placeholder tokens while preserving gradient
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
def save(self, filename):
# todo check to see how to get the vector out of the embedding
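# the layout below mirrors the AUTOMATIC1111 textual inversion format (string_to_token / string_to_param)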
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": 0,
# todo get these
"sd_checkpoint": None,
"sd_checkpoint_name": None,
"notes": None,
}
if filename.endswith('.pt'):
torch.save(embedding_data, filename)
elif filename.endswith('.bin'):
torch.save(embedding_data, filename)
elif filename.endswith('.safetensors'):
# save the embedding as a safetensors file
state_dict = {"emb_params": self.vec}
# add all embedding data (except string_to_param), to metadata
metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
metadata["string_to_param"] = {"*": "emb_params"}
save_meta = get_meta_for_safetensors(metadata, name=self.name)
save_file(state_dict, filename, metadata=save_meta)
def load_embedding_from_file(self, file_path, device):
# full path
path = os.path.realpath(file_path)
filename = os.path.basename(path)
name, ext = os.path.splitext(filename)
ext = ext.upper()
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
_, second_ext = os.path.splitext(name)
if second_ext.upper() == '.PREVIEW':
return
if ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
# rebuild the embedding from the safetensors file if it has it
tensors = {}
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
for k in f.keys():
tensors[k] = f.get_tensor(k)
# data = safetensors.torch.load_file(path, device="cpu")
if metadata and 'string_to_param' in metadata and 'emb_params' in tensors:
# our format
def try_json(v):
try:
return json.loads(v)
except Exception:
return v
data = {k: try_json(v) for k, v in metadata.items()}
data['string_to_param'] = {'*': tensors['emb_params']}
else:
# old format
data = tensors
else:
return
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict,
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
self.vec = emb.detach().to(device, dtype=torch.float32)


@@ -520,6 +520,11 @@ class StableDiffusion:
noise_pred_text - noise_pred_uncond
)
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
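# rescale_noise_cfg pulls the guided prediction back toward the std of the text-conditioned branch to reduce over-exposure at high CFG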
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
return noise_pred
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746