Switched to a new bucket system that matches the buckets SDXL was trained on. Fixed requirements. Updated embeddings to work with SDXL. Added a method to train a LoRA with an embedding at the trigger. Still testing, but it works amazingly well from what I can see.

Jaret Burkett
2023-09-07 13:06:18 -06:00
parent 436bf0c6a3
commit 3feb663a51
10 changed files with 208 additions and 140 deletions

View File

@@ -36,8 +36,6 @@ class ConceptReplacer(BaseSDTrainProcess):
# textual inversion
if self.embedding is not None:
# keep original embeddings as reference
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
# set text encoder to train. Not sure if this is necessary, but the diffusers example did it
self.sd.text_encoder.train()
@@ -142,13 +140,7 @@ class ConceptReplacer(BaseSDTrainProcess):
if self.embedding is not None:
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
index_no_updates[
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
with torch.no_grad():
self.sd.text_encoder.get_input_embeddings().weight[
index_no_updates
] = self.orig_embeds_params[index_no_updates]
self.embedding.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}
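
The block removed above is exactly what the new restore_embeddings() call encapsulates: after each step, every text-encoder embedding row except the newly added placeholder tokens is copied back from a saved reference, so only the trigger tokens actually learn. A minimal standalone sketch of that pattern (the helper name and arguments here are illustrative, not the project's API):

import torch

def restore_frozen_embedding_rows(text_encoder, tokenizer, placeholder_token_ids, orig_embeds):
    # mark every row as frozen except the newly added placeholder tokens
    index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
    index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
    with torch.no_grad():
        # copy the original values back over the frozen rows
        text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds[index_no_updates]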

View File

@@ -26,11 +26,9 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to(self.device_torch)
# textual inversion
if self.embedding is not None:
# keep original embeddings as reference
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
# if self.embedding is not None:
# set text encoder to train. Not sure if this is necessary, but the diffusers example did it
self.sd.text_encoder.train()
# self.sd.text_encoder.train()
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
@@ -103,13 +101,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.embedding is not None:
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
index_no_updates[
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
with torch.no_grad():
self.sd.text_encoder.get_input_embeddings().weight[
index_no_updates
] = self.orig_embeds_params[index_no_updates]
self.embedding.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}

View File

@@ -5,7 +5,7 @@ from collections import OrderedDict
import os
from typing import Union
from lycoris.config import PRESET
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
from toolkit.data_loader import get_dataloader_from_datasets
@@ -126,7 +126,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.embedding = None
self.embedding: Union[Embedding, None] = None
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
@@ -261,13 +261,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
# if we are doing embedding training as well, add that
embedding_dict = self.embedding.state_dict() if self.embedding else None
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta
metadata=save_meta,
extra_state_dict=embedding_dict
)
self.network.multiplier = prev_multiplier
# if we have an embedding as well, pair it with the network
elif self.embedding is not None:
# for the combined network + embedding case, the branch above already saved it
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
@@ -330,16 +336,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
def load_weights(self, path):
if self.network is not None:
self.network.load_weights(path)
extra_weights = self.network.load_weights(path)
meta = load_metadata_from_safetensors(path)
# if 'training_info' is in the OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
return extra_weights
else:
print("load_weights not implemented for non-network models")
return None
def process_general_training_batch(self, batch):
with torch.no_grad():
@@ -479,9 +486,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
if is_lycoris:
preset = PRESET['full']
# NetworkClass.apply_preset(preset)
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
self.network = NetworkClass(
text_encoder=text_encoder,
@@ -533,12 +540,25 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network.is_normalizing = self.network_config.normalize
latest_save_path = self.get_latest_save_path()
extra_weights = None
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.print(f"Loading from {latest_save_path}")
self.load_weights(latest_save_path)
extra_weights = self.load_weights(latest_save_path)
self.network.multiplier = 1.0
if self.embed_config is not None:
# we are doing embedding training as well
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config,
state_dict=extra_weights
)
params.append({
'params': self.embedding.get_trainable_params(),
'lr': self.train_config.embedding_lr
})
flush()
elif self.embed_config is not None:
self.embedding = Embedding(
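
With combined training, the embedding's trainable parameters get their own optimizer group at the new embedding_lr, separate from the network's learning rate. A toy sketch of that param-group layout (the tensors and learning rates are placeholders):

import torch

# placeholders standing in for the LoRA network parameters and the Embedding's token rows
network_params = [torch.nn.Parameter(torch.zeros(4, 4))]
embedding_params = [torch.nn.Parameter(torch.zeros(2, 768))]

optimizer = torch.optim.AdamW([
    {'params': network_params, 'lr': 1e-4},     # unet / text encoder LoRA weights
    {'params': embedding_params, 'lr': 5e-4},   # trigger-token embedding rows (embedding_lr)
])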

View File

@@ -16,4 +16,7 @@ toml
albumentations
pydantic
omegaconf
k-diffusion
k-diffusion
open_clip_torch
timm
prodigyopt

View File

@@ -1,6 +1,10 @@
from typing import Type, List, Union
from typing import Type, List, Union, TypedDict
class BucketResolution(TypedDict):
width: int
height: int
BucketResolution = Type[{"width": int, "height": int}]
# resolutions SDXL was trained on with a 1024x1024 base resolution
resolutions_1024: List[BucketResolution] = [
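
get_bucket_for_image_size, used by the data loader below, presumably picks the bucket whose aspect ratio is closest to the image's. A simplified sketch under that assumption (the real helper also takes a resolution argument and the full SDXL bucket list; only a few buckets are shown):

from typing import List, TypedDict

class BucketResolution(TypedDict):
    width: int
    height: int

# a small subset of the 1024-base buckets, for illustration only
resolutions_1024: List[BucketResolution] = [
    {"width": 1024, "height": 1024},
    {"width": 1152, "height": 896},
    {"width": 896, "height": 1152},
]

def get_bucket_for_image_size(width: int, height: int) -> BucketResolution:
    # choose the bucket whose aspect ratio is closest to the image's
    aspect = width / height
    return min(resolutions_1024, key=lambda b: abs(b["width"] / b["height"] - aspect))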

View File

@@ -72,6 +72,7 @@ class TrainConfig:
self.lr = kwargs.get('lr', 1e-6)
self.unet_lr = kwargs.get('unet_lr', self.lr)
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
self.embedding_lr = kwargs.get('embedding_lr', self.lr)
self.optimizer = kwargs.get('optimizer', 'adamw')
self.optimizer_params = kwargs.get('optimizer_params', {})
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')

View File

@@ -3,6 +3,7 @@ import os
import random
from typing import TYPE_CHECKING, List, Dict, Union
from toolkit.buckets import get_bucket_for_image_size
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image
@@ -102,54 +103,21 @@ class BucketsMixin:
width = file_item.crop_width
height = file_item.crop_height
# determine new resolution to have the same number of pixels
current_pixels = width * height
if current_pixels == total_pixels:
file_item.scale_to_width = width
file_item.scale_to_height = height
file_item.crop_width = width
file_item.crop_height = height
new_width = width
new_height = height
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution)
# set the scaling height and width so the smaller side matches the bucket, keeping the aspect ratio
if width > height:
file_item.scale_height = bucket_resolution["height"]
file_item.scale_width = int(width * (bucket_resolution["height"] / height))
else:
file_item.scale_width = bucket_resolution["width"]
file_item.scale_height = int(height * (bucket_resolution["width"] / width))
aspect_ratio = width / height
new_height = int(math.sqrt(total_pixels / aspect_ratio))
new_width = int(aspect_ratio * new_height)
file_item.crop_height = bucket_resolution["height"]
file_item.crop_width = bucket_resolution["width"]
# increase smallest one to be divisible by bucket_tolerance and increase the other to match
if new_width < new_height:
# increase width
if new_width % bucket_tolerance != 0:
crop_amount = new_width % bucket_tolerance
new_width = new_width + (bucket_tolerance - crop_amount)
new_height = int(new_width / aspect_ratio)
else:
# increase height
if new_height % bucket_tolerance != 0:
crop_amount = new_height % bucket_tolerance
new_height = new_height + (bucket_tolerance - crop_amount)
new_width = int(aspect_ratio * new_height)
# Ensure that the total number of pixels remains the same.
# assert new_width * new_height == total_pixels
file_item.scale_to_width = new_width
file_item.scale_to_height = new_height
file_item.crop_width = new_width
file_item.crop_height = new_height
# make sure it is divisible by bucket_tolerance, decrease if not
if new_width % bucket_tolerance != 0:
crop_amount = new_width % bucket_tolerance
file_item.crop_width = new_width - crop_amount
else:
file_item.crop_width = new_width
if new_height % bucket_tolerance != 0:
crop_amount = new_height % bucket_tolerance
file_item.crop_height = new_height - crop_amount
else:
file_item.crop_height = new_height
new_width = bucket_resolution["width"]
new_height = bucket_resolution["height"]
# check if bucket exists, if not, create it
bucket_key = f'{new_width}x{new_height}'
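
In the new path, the image is scaled so its smaller side matches the bucket and is then cropped to the bucket's exact dimensions, instead of the old same-pixel-count arithmetic. A worked example for a 3000x2000 landscape image landing in a 1152x896 bucket (bucket choice assumed):

# mirrors the scaling logic above with concrete numbers
width, height = 3000, 2000
bucket = {"width": 1152, "height": 896}

if width > height:
    scale_height = bucket["height"]                          # 896
    scale_width = int(width * (bucket["height"] / height))   # int(3000 * 0.448) = 1344
else:
    scale_width = bucket["width"]
    scale_height = int(height * (bucket["width"] / width))

crop_width, crop_height = bucket["width"], bucket["height"]  # crop the 1344x896 image to 1152x896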

View File

@@ -21,7 +21,8 @@ class Embedding:
def __init__(
self,
sd: 'StableDiffusion',
embed_config: 'EmbeddingConfig'
embed_config: 'EmbeddingConfig',
state_dict: OrderedDict = None,
):
self.name = embed_config.trigger
self.sd = sd
@@ -38,74 +39,112 @@ class Embedding:
additional_tokens.append(f"{self.embed_config.trigger}_{i}")
placeholder_tokens += additional_tokens
num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.embed_config.tokens:
raise ValueError(
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# handle dual tokenizer
self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
self.sd.text_encoder]
# Convert the initializer_token, placeholder_token to ids
init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
# if there are more init token ids than embedding tokens, truncate; if fewer, pad with *
if len(init_token_ids) > self.embed_config.tokens:
init_token_ids = init_token_ids[:self.embed_config.tokens]
elif len(init_token_ids) < self.embed_config.tokens:
pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
self.placeholder_token_ids = []
self.embedding_tokens = []
self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.embed_config.tokens:
raise ValueError(
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
# todo SDXL has 2 text encoders, need to do both for all of this
self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))
# Convert the initializer_token, placeholder_token to ids
init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
# if there are more init token ids than embedding tokens, truncate; if fewer, pad with *
if len(init_token_ids) > self.embed_config.tokens:
init_token_ids = init_token_ids[:self.embed_config.tokens]
elif len(init_token_ids) < self.embed_config.tokens:
pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
self.placeholder_token_ids.append(placeholder_token_ids)
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# returns the string to have in the prompt to trigger the embedding
def get_embedding_string(self):
return self.embedding_tokens
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
# backup text encoder embeddings
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
def restore_embeddings(self):
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
with torch.no_grad():
text_encoder.get_input_embeddings().weight[
index_no_updates
] = orig_embeds[index_no_updates]
def get_trainable_params(self):
# todo: this only gets one text encoder, but there could be more than one
return self.sd.text_encoder.get_input_embeddings().parameters()
params = []
for text_encoder in self.text_encoder_list:
params += text_encoder.get_input_embeddings().parameters()
return params
# make setter and getter for vec
@property
def vec(self):
def _get_vec(self, text_encoder_idx=0):
# should we get params instead
# create vector from token embeds
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
# stack the tokens along batch axis adding that axis
new_vector = torch.stack(
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
[token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
dim=0
)
return new_vector
@vec.setter
def vec(self, new_vector):
def _set_vec(self, new_vector, text_encoder_idx=0):
# shape is (1, 768) for SD 1.5 for 1 token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
token_embeds = self.text_encoder_list[0].get_input_embeddings().weight.data
for i in range(new_vector.shape[0]):
# apply the weights to the placeholder tokens while preserving gradient
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
x = 1
token_embeds[self.placeholder_token_ids[0][i]] = new_vector[i].clone()
# make setter and getter for vec
@property
def vec(self):
return self._get_vec(0)
@vec.setter
def vec(self, new_vector):
self._set_vec(new_vector, 0)
@property
def vec2(self):
return self._get_vec(1)
@vec2.setter
def vec2(self, new_vector):
self._set_vec(new_vector, 1)
# diffusers automatically expands the token, meaning test123 becomes test123 test123_1 test123_2 etc
# however, during training we don't use that pipeline, so we have to do it ourselves
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
output_prompt = prompt
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
embedding_tokens = self.embedding_tokens[0]  # should be the same for both tokenizers
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", embedding_tokens]
replace_with = self.embedding_tokens if expand_token else self.trigger
replace_with = embedding_tokens if expand_token else self.trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
@@ -132,6 +171,17 @@ class Embedding:
return output_prompt
def state_dict(self):
if self.sd.is_xl:
state_dict = OrderedDict()
state_dict['clip_l'] = self.vec
state_dict['clip_g'] = self.vec2
else:
state_dict = OrderedDict()
state_dict['emb_params'] = self.vec
return state_dict
def save(self, filename):
# todo check to see how to get the vector out of the embedding
@@ -145,13 +195,14 @@ class Embedding:
"sd_checkpoint_name": None,
"notes": None,
}
# TODO we do not currently support this. Check how auto is doing it. Only safetensors is supported for sdxl
if filename.endswith('.pt'):
torch.save(embedding_data, filename)
elif filename.endswith('.bin'):
torch.save(embedding_data, filename)
elif filename.endswith('.safetensors'):
# save the embedding as a safetensors file
state_dict = {"emb_params": self.vec}
state_dict = self.state_dict()
# add all embedding data (except string_to_param), to metadata
metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
metadata["string_to_param"] = {"*": "emb_params"}
@@ -163,6 +214,7 @@ class Embedding:
path = os.path.realpath(file_path)
filename = os.path.basename(path)
name, ext = os.path.splitext(filename)
tensors = {}
ext = ext.upper()
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
_, second_ext = os.path.splitext(name)
@@ -170,10 +222,12 @@ class Embedding:
return
if ext in ['.BIN', '.PT']:
# todo check this
if self.sd.is_xl:
raise Exception("XL not supported yet for bin, pt")
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
# rebuild the embedding from the safetensors file if it has it
tensors = {}
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
for k in f.keys():
@@ -217,4 +271,8 @@ class Embedding:
if 'step' in data:
self.step = int(data['step'])
self.vec = emb.detach().to(device, dtype=torch.float32)
if self.sd.is_xl:
self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
else:
self.vec = emb.detach().to(device, dtype=torch.float32)
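
Two details worth noting in this rewrite: the trigger is expanded into one placeholder token per embedding vector (mirroring how diffusers expands multi-vector embeddings at inference), and for SDXL the state dict carries one tensor per text encoder under the clip_l and clip_g keys. A small illustration of the expansion and prompt replacement (simplified, not the class's full logic):

trigger = "test123"
num_tokens = 3

# expanded placeholder string: "test123 test123_1 test123_2"
placeholder_tokens = [trigger] + [f"{trigger}_{i}" for i in range(1, num_tokens)]
embedding_tokens = " ".join(placeholder_tokens)

prompt = "a photo of [trigger] on a beach"
training_prompt = prompt.replace("[trigger]", embedding_tokens)
# -> "a photo of test123 test123_1 test123_2 on a beach"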

View File

@@ -228,9 +228,15 @@ class ToolkitNetworkMixin:
return keymap
def save_weights(self: Network, file, dtype=torch.float16, metadata=None):
def save_weights(
self: Network,
file, dtype=torch.float16,
metadata=None,
extra_state_dict: Optional[OrderedDict] = None
):
keymap = self.get_keymap()
save_keymap = {}
if keymap is not None:
for ldm_key, diffusers_key in keymap.items():
@@ -249,6 +255,13 @@ class ToolkitNetworkMixin:
save_key = save_keymap[key] if key in save_keymap else key
save_dict[save_key] = v
if extra_state_dict is not None:
# add extra items to state dict
for key in list(extra_state_dict.keys()):
v = extra_state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
save_dict[key] = v
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(state_dict, metadata)
@@ -275,8 +288,21 @@ class ToolkitNetworkMixin:
load_key = keymap[key] if key in keymap else key
load_sd[load_key] = value
# extract extra items from state dict
current_state_dict = self.state_dict()
extra_dict = OrderedDict()
to_delete = []
for key in list(load_sd.keys()):
if key not in current_state_dict:
extra_dict[key] = load_sd[key]
to_delete.append(key)
for key in to_delete:
del load_sd[key]
info = self.load_state_dict(load_sd, False)
return info
if len(extra_dict.keys()) == 0:
extra_dict = None
return extra_dict
@property
def multiplier(self) -> Union[float, List[float]]:
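
This is what lets a single safetensors file hold both the LoRA and the embedding: extra tensors (e.g. clip_l/clip_g or emb_params) are written next to the network keys on save, and on load any key the network does not recognize is split out and returned so the Embedding can consume it. A quick way to inspect such a combined file (the path is a placeholder; key prefixes may differ):

from safetensors import safe_open

path = "my_lora_with_embedding.safetensors"  # placeholder path
with safe_open(path, framework="pt", device="cpu") as f:
    keys = list(f.keys())
    network_keys = [k for k in keys if k.startswith("lora")]
    extra_keys = [k for k in keys if not k.startswith("lora")]
    # for an SDXL embedding, extra_keys would be something like ['clip_g', 'clip_l']
    print(len(network_keys), "network tensors; extras:", extra_keys)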

View File

@@ -442,21 +442,25 @@ class StableDiffusion:
return noise
def get_time_ids_from_latents(self, latents: torch.Tensor):
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = latents.dtype
if self.is_xl:
prompt_ids = train_tools.get_add_time_ids(
height,
width,
dynamic_crops=False, # look into this
dtype=dtype,
).to(self.device_torch, dtype=dtype)
return prompt_ids
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = latents.dtype
# just do it without any cropping nonsense
target_size = (height, width)
original_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
batch_time_ids = torch.cat(
[add_time_ids for _ in range(bs)]
)
return batch_time_ids
else:
return None
@@ -682,7 +686,7 @@ class StableDiffusion:
if self.vae.device == 'cpu':
self.vae.to(self.device)
latents = latents.to(device, dtype=dtype)
latents = latents / 0.18215
latents = latents / self.vae.config['scaling_factor']
images = self.vae.decode(latents).sample
images = images.to(device, dtype=dtype)
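
Two SDXL-specific fixes here: the added-conditioning time_ids are just the six integers original_size + crop_top_left + target_size repeated per batch item (latent height/width times the VAE scale factor of 8), and the hard-coded 0.18215 divisor is replaced with the model's own scaling factor, since the SDXL VAE uses 0.13025 while SD 1.5 uses 0.18215. A worked example for a batch of two 1024x1024 images:

import torch

bs, ch, h, w = 2, 4, 128, 128                                # latents for two 1024x1024 images
VAE_SCALE_FACTOR = 8
height, width = h * VAE_SCALE_FACTOR, w * VAE_SCALE_FACTOR   # 1024, 1024

original_size = (height, width)
crops_coords_top_left = (0, 0)
target_size = (height, width)

add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
batch_time_ids = torch.cat([add_time_ids for _ in range(bs)])
# tensor([[1024, 1024, 0, 0, 1024, 1024],
#         [1024, 1024, 0, 0, 1024, 1024]])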