Merge pull request #7 from ostris/textual_inversion

Textual inversion training
Jaret Burkett
2023-08-23 13:31:37 -06:00
committed by GitHub
9 changed files with 764 additions and 8 deletions

View File

@@ -0,0 +1,152 @@
import copy
import random
from collections import OrderedDict
import os
from contextlib import nullcontext
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader
from toolkit.config_modules import ReferenceDatasetConfig
from toolkit.data_loader import PairedImageDataset, ImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
import random
from toolkit.basic import value_map
def flush():
torch.cuda.empty_cache()
gc.collect()
class TextualInversionTrainer(BaseSDTrainProcess):
sd: StableDiffusion
data_loader: DataLoader = None
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
pass
def before_model_load(self):
pass
def hook_before_train_loop(self):
self.sd.vae.eval()
self.sd.vae.to(self.device_torch)
# keep original embeddings as reference
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
# put the text encoder in train mode; the diffusers example does this, though it may not be strictly necessary
self.sd.text_encoder.train()
pass
def hook_train_loop(self, batch):
with torch.no_grad():
imgs, prompts = batch
# very loosely based on the diffusers textual inversion example:
# ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
# make sure the embedding is in the prompts
conditioned_prompts = [self.embedding.inject_embedding_to_prompt(
x,
expand_token=True,
add_if_not_present=True,
) for x in prompts]
batch_size = imgs.shape[0]
dtype = get_torch_dtype(self.train_config.dtype)
imgs = imgs.to(self.device_torch, dtype=dtype)
latents = self.sd.encode_images(imgs)
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
timesteps = timesteps.long()
# get noise
noise = self.sd.get_latent_noise(
pixel_height=imgs.shape[2],
pixel_width=imgs.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# remove grads for these
noisy_latents.requires_grad = False
noise.requires_grad = False
flush()
self.optimizer.zero_grad()
# text encoding
embedding_list = []
# embed the prompts
for prompt in conditioned_prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
)
noise = noise.to(self.device_torch, dtype=dtype)
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# backpropagate now so intermediate activations can be freed before the optimizer step
loss.backward()
flush()
# apply gradients
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
index_no_updates[
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
with torch.no_grad():
self.sd.text_encoder.get_input_embeddings().weight[
index_no_updates
] = self.orig_embeds_params[index_no_updates]
loss_dict = OrderedDict(
{'loss': loss.item()}
)
return loss_dict
# end hook_train_loop
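For reference, the row masking at the end of hook_train_loop reduces to the standalone sketch below; the vocabulary size and placeholder token ids are hypothetical, not taken from the trainer.

    # After optimizer.step() has touched the whole embedding matrix, copy the original
    # values back into every row except the newly added placeholder tokens.
    import torch

    vocab_size, dim = 49412, 768            # e.g. CLIP vocab plus 4 added placeholder tokens
    embeddings = torch.nn.Embedding(vocab_size, dim)
    placeholder_token_ids = [49408, 49409, 49410, 49411]
    orig_embeds = embeddings.weight.data.clone()

    # ... a training step updates embeddings.weight here ...

    index_no_updates = torch.ones((vocab_size,), dtype=torch.bool)
    index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
    with torch.no_grad():
        # only the placeholder rows keep their updated values
        embeddings.weight[index_no_updates] = orig_embeds[index_no_updates]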

View File

@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# We make a subclass of Extension
class TextualInversionTrainerExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "textual_inversion_trainer"
# name is the name of the extension for printing
name = "Textual Inversion Trainer"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .TextualInversionTrainer import TextualInversionTrainer
return TextualInversionTrainer
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
TextualInversionTrainerExtension
]
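The uid above is what the example config refers to through its process `type` field. The toolkit's extension loader is not part of this diff; the sketch below only illustrates, with a hypothetical find_extension helper, how that match could be resolved.

    from typing import List, Type
    from toolkit.extension import Extension

    def find_extension(extensions: List[Type[Extension]], process_type: str) -> Type[Extension]:
        # hypothetical resolver, not the toolkit's actual loader
        for ext in extensions:
            if ext.uid == process_type:
                return ext
        raise ValueError(f"no extension registered for process type '{process_type}'")

    # e.g. find_extension(AI_TOOLKIT_EXTENSIONS, "textual_inversion_trainer").get_process()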

View File

@@ -0,0 +1,92 @@
---
job: extension
config:
name: test_v1
process:
- type: 'textual_inversion_trainer'
training_folder: "out/TI"
device: cuda:0
# for tensorboard logging
log_dir: "out/.tensorboard"
embedding:
trigger: "your_trigger_here"
tokens: 12
init_words: "man with short brown hair"
save_format: "safetensors" # 'safetensors' or 'pt'
save:
dtype: float16 # precision to save
save_every: 100 # save every this many steps
max_step_saves_to_keep: 5 # only step-numbered saves count toward this limit
datasets:
- folder_path: "/path/to/dataset"
caption_type: "txt"
default_caption: "[trigger]"
buckets: true
resolution: 512
train:
noise_scheduler: "ddpm" # or "lms", "euler_a"
steps: 3000
weight_jitter: 0.0
lr: 5e-5
train_unet: false
gradient_checkpointing: true
train_text_encoder: false
optimizer: "adamw"
# optimizer: "prodigy"
optimizer_params:
weight_decay: 1e-2
lr_scheduler: "constant"
max_denoising_steps: 1000
batch_size: 4
dtype: bf16
xformers: true
min_snr_gamma: 5.0
# skip_first_sample: true
noise_offset: 0.0 # not needed for this
model:
# objective reality v2
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
is_v2: false # for v2 models
is_xl: false # for SDXL models
is_v_pred: false # for v-prediction models (most v2 models)
sample:
sampler: "ddpm" # must match train.noise_scheduler
sample_every: 100 # sample every this many steps
width: 512
height: 512
prompts:
- "photo of [trigger] laughing"
- "photo of [trigger] smiling"
- "[trigger] close up"
- "dark scene [trigger] frozen"
- "[trigger] nighttime"
- "a painting of [trigger]"
- "a drawing of [trigger]"
- "a cartoon of [trigger]"
- "[trigger] pixar style"
- "[trigger] costume"
neg: ""
seed: 42
walk_seed: false
guidance_scale: 7
sample_steps: 20
network_multiplier: 1.0
logging:
log_every: 10 # log every this many steps
use_wandb: false # not supported yet
verbose: false
# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
# plus some other information for you automatically
meta:
# [name] gets replaced with the name above
name: "[name]"
# version: '1.0'
# creator:
# name: Your Name
# email: your@gmail.com
# website: https://your.website
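The `embedding` and `datasets` blocks above map directly onto the EmbeddingConfig and DatasetConfig classes added in this PR. A minimal sketch of that mapping, assuming PyYAML and a hypothetical file path:

    import yaml
    from toolkit.config_modules import EmbeddingConfig, DatasetConfig

    with open("config/train_textual_inversion.yml", "r") as f:   # hypothetical path
        raw = yaml.safe_load(f)

    process_cfg = raw["config"]["process"][0]
    embed_config = EmbeddingConfig(**process_cfg["embedding"])   # trigger, tokens, init_words, save_format
    datasets = [DatasetConfig(**d) for d in process_cfg["datasets"]]
    print(embed_config.trigger, embed_config.tokens, datasets[0].folder_path)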

View File

@@ -5,6 +5,8 @@ from typing import Union
from torch.utils.data import DataLoader
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -20,7 +22,7 @@ import torch
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
- GenerateImageConfig
+ GenerateImageConfig, EmbeddingConfig, DatasetConfig
def flush():
@@ -30,6 +32,7 @@ def flush():
class BaseSDTrainProcess(BaseTrainProcess):
sd: StableDiffusion
embedding: Union[Embedding, None] = None
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
super().__init__(process_id, job, config)
@@ -59,6 +62,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
raw_datasets = self.get_conf('datasets', None)
self.datasets = None
if raw_datasets is not None and len(raw_datasets) > 0:
self.datasets = [DatasetConfig(**d) for d in raw_datasets]
self.embed_config = None
embedding_raw = self.get_conf('embedding', None)
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
@@ -68,6 +81,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network = None
self.embedding = None
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
@@ -89,8 +103,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
output_path = os.path.join(sample_folder, filename)
prompt = sample_config.prompts[i]
# add embedding if there is one
# note: diffusers will automatically expand the trigger to the number of added tokens,
# i.e. test123 becomes test123 test123_1 test123_2 etc. Do not expand it manually here
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
)
gen_img_config_list.append(GenerateImageConfig(
- prompt=sample_config.prompts[i], # it will autoparse the prompt
+ prompt=prompt, # it will autoparse the prompt
width=sample_config.width,
height=sample_config.height,
negative_prompt=sample_config.neg,
@@ -175,6 +199,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
metadata=save_meta
)
self.network.multiplier = prev_multiplier
elif self.embedding is not None:
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
if self.embed_config.save_format == "pt":
# replace extension
file_path = os.path.splitext(file_path)[0] + ".pt"
self.embedding.save(file_path)
else:
self.sd.save(
file_path,
@@ -197,7 +229,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
def hook_before_train_loop(self):
pass
- def hook_train_loop(self, batch=None):
def before_dataset_load(self):
pass
def hook_train_loop(self, batch):
# return loss
return 0.0
@@ -208,6 +243,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
pattern = f"{self.job.name}*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
# try pt
pattern = f"{self.job.name}*.pt"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
return latest_file
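A sketch (not the repo's code) that folds the two glob passes above into a single most-recent lookup across both supported extensions:

    import glob
    import os
    from typing import Optional

    def get_latest_save_path(save_root: str, job_name: str) -> Optional[str]:
        # collect candidates for both formats, then pick the newest by creation time
        files = []
        for pattern in (f"{job_name}*.safetensors", f"{job_name}*.pt"):
            files.extend(glob.glob(os.path.join(save_root, pattern)))
        if not files:
            return None
        return max(files, key=os.path.getctime)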
@@ -230,11 +270,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
def run(self):
# run base process run
BaseTrainProcess.run(self)
### HOOK ###
self.before_dataset_load()
# load datasets if passed in the root process
if self.datasets is not None:
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
### HOOK ###
self.hook_before_model_load()
# run base sd process run
self.sd.load_model()
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
dtype = get_torch_dtype(self.train_config.dtype)
# model is loaded from BaseSDProcess
@@ -303,7 +353,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print(f"Loading from {latest_save_path}")
self.load_weights(latest_save_path)
self.network.multiplier = 1.0
elif self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path()
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# resume state from embedding
self.step_num = self.embedding.step
# set trainable params
params = self.embedding.get_trainable_params()
else:
params = []

View File

@@ -1,6 +1,6 @@
import os
import time
- from typing import List, Optional
+ from typing import List, Optional, Literal
import random
@@ -50,6 +50,14 @@ class NetworkConfig:
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
class EmbeddingConfig:
def __init__(self, **kwargs):
self.trigger = kwargs.get('trigger', 'custom_embedding')
self.tokens = kwargs.get('tokens', 4)
self.init_words = kwargs.get('init_words', '*')
self.save_format = kwargs.get('save_format', 'safetensors')
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
@@ -68,6 +76,7 @@ class TrainConfig:
self.optimizer_params = kwargs.get('optimizer_params', {})
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
class ModelConfig:
@@ -146,6 +155,21 @@ class SliderConfig:
self.targets.append(target)
class DatasetConfig:
caption_type: Literal["txt", "caption"] = 'txt'
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'image') # sd, slider, reference
self.folder_path: str = kwargs.get('folder_path', None)
self.default_caption: str = kwargs.get('default_caption', None)
self.caption_type: str = kwargs.get('caption_type', None)
self.random_scale: bool = kwargs.get('random_scale', False)
self.random_crop: bool = kwargs.get('random_crop', False)
self.resolution: int = kwargs.get('resolution', 512)
self.scale: float = kwargs.get('scale', 1.0)
self.buckets: bool = kwargs.get('buckets', False)
class GenerateImageConfig:
def __init__(
self,

View File

@@ -1,23 +1,40 @@
import os
import random
from typing import List
import cv2
import numpy as np
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
- from torch.utils.data import Dataset
+ from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A
from toolkit.config_modules import DatasetConfig
from toolkit.dataloader_mixins import CaptionMixin
- class ImageDataset(Dataset):
BUCKET_STEPS = 64
def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
# make sure resolution is divisible by 8
if resolution % 8 != 0:
resolution = resolution - (resolution % 8)
class ImageDataset(Dataset, CaptionMixin):
def __init__(self, config):
self.config = config
self.name = self.get_config('name', 'dataset')
self.path = self.get_config('path', required=True)
self.scale = self.get_config('scale', 1)
self.random_scale = self.get_config('random_scale', False)
self.include_prompt = self.get_config('include_prompt', False)
self.default_prompt = self.get_config('default_prompt', '')
if self.include_prompt:
self.caption_type = self.get_config('caption_type', 'txt')
else:
self.caption_type = None
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)
@@ -81,7 +98,11 @@ class ImageDataset(Dataset):
img = self.transform(img)
return img
if self.include_prompt:
prompt = self.get_caption_item(index)
return img, prompt
else:
return img
class Augments:
@@ -268,3 +289,102 @@ class PairedImageDataset(Dataset):
img = self.transform(img)
return img, prompt, (self.neg_weight, self.pos_weight)
class AiToolkitDataset(Dataset, CaptionMixin):
def __init__(self, dataset_config: 'DatasetConfig'):
self.dataset_config = dataset_config
self.folder_path = dataset_config.folder_path
self.caption_type = dataset_config.caption_type
self.default_caption = dataset_config.default_caption
self.random_scale = dataset_config.random_scale
self.scale = dataset_config.scale
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
self.resolution = dataset_config.resolution
# get the file list
self.file_list = [
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
# this might take a while
print(f" - Preprocessing image dimensions")
new_file_list = []
bad_count = 0
for file in tqdm(self.file_list):
img = Image.open(file)
if int(min(img.size) * self.scale) >= self.resolution:
new_file_list.append(file)
else:
bad_count += 1
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
img_path = self.file_list[index]
img = exif_transpose(Image.open(img_path)).convert('RGB')
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
min_img_size = min(img.size)
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)
if self.caption_type is not None:
prompt = self.get_caption_item(index)
return img, prompt
else:
return img
def get_dataloader_from_datasets(dataset_options, batch_size=1):
# TODO do bucketing
if dataset_options is None or len(dataset_options) == 0:
return None
datasets = []
for dataset_option in dataset_options:
if isinstance(dataset_option, DatasetConfig):
config = dataset_option
else:
config = DatasetConfig(**dataset_option)
if config.type == 'image':
dataset = AiToolkitDataset(config)
datasets.append(dataset)
else:
raise ValueError(f"invalid dataset type: {config.type}")
concatenated_dataset = ConcatDataset(datasets)
data_loader = DataLoader(
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
)
return data_loader
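A minimal usage sketch of the new loader, mirroring how BaseSDTrainProcess.run wires it up; the folder path is a placeholder:

    from toolkit.config_modules import DatasetConfig
    from toolkit.data_loader import get_dataloader_from_datasets

    datasets = [DatasetConfig(type='image', folder_path='/path/to/dataset',
                              caption_type='txt', default_caption='[trigger]', resolution=512)]
    data_loader = get_dataloader_from_datasets(datasets, batch_size=4)
    for imgs, prompts in data_loader:
        # AiToolkitDataset returns (img, prompt) pairs when caption_type is set
        print(imgs.shape, prompts[:2])
        break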

View File

@@ -0,0 +1,43 @@
import os
class CaptionMixin:
def get_caption_item(self, index):
if not hasattr(self, 'caption_type'):
raise Exception('caption_type not found on class instance')
if not hasattr(self, 'file_list'):
raise Exception('file_list not found on class instance')
img_path_or_tuple = self.file_list[index]
if isinstance(img_path_or_tuple, tuple):
# check if either has a prompt file
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
prompt_path = path_no_ext + '.txt'
if not os.path.exists(prompt_path):
path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
prompt_path = path_no_ext + '.txt'
else:
img_path = img_path_or_tuple
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
# remove any newlines
prompt = prompt.replace('\n', ', ')
# remove new lines for all operating systems
prompt = prompt.replace('\r', ', ')
prompt_split = prompt.split(',')
# remove empty strings
prompt_split = [p.strip() for p in prompt_split if p.strip()]
# join back together
prompt = ', '.join(prompt_split)
else:
prompt = ''
# get default_prompt if it exists on the class instance
if hasattr(self, 'default_prompt'):
prompt = self.default_prompt
if hasattr(self, 'default_caption'):
prompt = self.default_caption
return prompt
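A minimal sketch of a CaptionMixin consumer: any class exposing file_list, caption_type, and optionally default_caption gets get_caption_item for free. The paths below are hypothetical.

    from toolkit.dataloader_mixins import CaptionMixin

    class TinyCaptionedList(CaptionMixin):
        def __init__(self, files):
            self.file_list = files              # e.g. ['/data/img_0001.jpg', ...]
            self.caption_type = 'txt'           # look for img_0001.txt next to the image
            self.default_caption = '[trigger]'  # fallback caption

    ds = TinyCaptionedList(['/data/img_0001.jpg'])
    # prints the caption file contents, or default_caption when no .txt file is found
    print(ds.get_caption_item(0))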

toolkit/embedding.py (new file, 220 lines)
View File

@@ -0,0 +1,220 @@
import json
import os
from collections import OrderedDict
import safetensors
import torch
from typing import TYPE_CHECKING
from safetensors.torch import save_file
from toolkit.metadata import get_meta_for_safetensors
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import EmbeddingConfig
# this is a frankenstein mix of automatic1111 and my own code
class Embedding:
def __init__(
self,
sd: 'StableDiffusion',
embed_config: 'EmbeddingConfig'
):
self.name = embed_config.trigger
self.sd = sd
self.trigger = embed_config.trigger
self.embed_config = embed_config
self.step = 0
# setup our embedding
# Add the placeholder token in tokenizer
placeholder_tokens = [self.embed_config.trigger]
# add dummy tokens for multi-vector
additional_tokens = []
for i in range(1, self.embed_config.tokens):
additional_tokens.append(f"{self.embed_config.trigger}_{i}")
placeholder_tokens += additional_tokens
num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.embed_config.tokens:
raise ValueError(
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Convert the initializer_token, placeholder_token to ids
init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
# if there are more init token ids than embedding tokens, truncate; if fewer, pad with '*'
if len(init_token_ids) > self.embed_config.tokens:
init_token_ids = init_token_ids[:self.embed_config.tokens]
elif len(init_token_ids) < self.embed_config.tokens:
pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
# todo SDXL has 2 text encoders, need to do both for all of this
self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
# returns the string to have in the prompt to trigger the embedding
def get_embedding_string(self):
return self.embedding_tokens
def get_trainable_params(self):
# todo: limit this to the current embedding's token rows, since there could be more than one embedding
return self.sd.text_encoder.get_input_embeddings().parameters()
# make setter and getter for vec
@property
def vec(self):
# should we get params instead
# create vector from token embeds
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
# stack the token embeddings along a new leading axis
new_vector = torch.stack(
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
dim=0
)
return new_vector
@vec.setter
def vec(self, new_vector):
# shape is (1, 768) for SD 1.5 for 1 token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
for i in range(new_vector.shape[0]):
# apply the weights to the placeholder tokens while preserving gradient
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
# diffusers automatically expands the token, so test123 becomes "test123 test123_1 test123_2 ..."
# however, we don't use that pipeline during training, so we have to do the expansion ourselves
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
output_prompt = prompt
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
replace_with = self.embedding_tokens if expand_token else self.trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
to_replace_list += default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with ended up in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
if num_instances > 1:
print(
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
def save(self, filename):
# todo check to see how to get the vector out of the embedding
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
# todo get these
"sd_checkpoint": None,
"sd_checkpoint_name": None,
"notes": None,
}
if filename.endswith('.pt'):
torch.save(embedding_data, filename)
elif filename.endswith('.bin'):
torch.save(embedding_data, filename)
elif filename.endswith('.safetensors'):
# save the embedding as a safetensors file
state_dict = {"emb_params": self.vec}
# add all embedding data (except string_to_param), to metadata
metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
metadata["string_to_param"] = {"*": "emb_params"}
save_meta = get_meta_for_safetensors(metadata, name=self.name)
save_file(state_dict, filename, metadata=save_meta)
def load_embedding_from_file(self, file_path, device):
# full path
path = os.path.realpath(file_path)
filename = os.path.basename(path)
name, ext = os.path.splitext(filename)
ext = ext.upper()
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
_, second_ext = os.path.splitext(name)
if second_ext.upper() == '.PREVIEW':
return
if ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
# rebuild the embedding from the safetensors file if it has it
tensors = {}
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
for k in f.keys():
tensors[k] = f.get_tensor(k)
# data = safetensors.torch.load_file(path, device="cpu")
if metadata and 'string_to_param' in metadata and 'emb_params' in tensors:
# our format
def try_json(v):
try:
return json.loads(v)
except Exception:
return v
data = {k: try_json(v) for k, v in metadata.items()}
data['string_to_param'] = {'*': tensors['emb_params']}
else:
# old format
data = tensors
else:
return
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict,
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(
f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")
if 'step' in data:
self.step = int(data['step'])
self.vec = emb.detach().to(device, dtype=torch.float32)

View File

@@ -435,7 +435,7 @@ class StableDiffusion:
text_embeddings = train_tools.concat_prompt_embeddings(
unconditional_embeddings, # negative embedding
conditional_embeddings, # positive embedding
- latents.shape[0], # batch size
+ 1, # batch size
)
elif text_embeddings is None and conditional_embeddings is not None:
# not doing cfg
@@ -506,6 +506,17 @@ class StableDiffusion:
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
# check if we need to concat timesteps
if isinstance(timestep, torch.Tensor):
ts_bs = timestep.shape[0]
if ts_bs != latent_model_input.shape[0]:
if ts_bs == 1:
timestep = torch.cat([timestep] * latent_model_input.shape[0])
elif ts_bs * 2 == latent_model_input.shape[0]:
timestep = torch.cat([timestep] * 2)
else:
raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
@@ -520,6 +531,11 @@ class StableDiffusion:
noise_pred_text - noise_pred_uncond
)
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
return noise_pred
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
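The timestep handling added above reduces to a simple broadcasting rule; a minimal standalone sketch:

    import torch

    def broadcast_timesteps(timestep: torch.Tensor, latent_batch: int) -> torch.Tensor:
        # a 1-element timestep is repeated to the latent batch size,
        # and a half-size batch is doubled for classifier-free guidance
        ts_bs = timestep.shape[0]
        if ts_bs == latent_batch:
            return timestep
        if ts_bs == 1:
            return torch.cat([timestep] * latent_batch)
        if ts_bs * 2 == latent_batch:
            return torch.cat([timestep] * 2)     # unconditional + conditional halves
        raise ValueError(f"timestep batch {ts_bs} is incompatible with latent batch {latent_batch}")

    print(broadcast_timesteps(torch.tensor([981]), 4))     # tensor([981, 981, 981, 981])
    print(broadcast_timesteps(torch.tensor([10, 20]), 4))  # tensor([10, 20, 10, 20])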