diff --git a/extensions_built_in/concept_replacer/ConceptReplacer.py b/extensions_built_in/concept_replacer/ConceptReplacer.py
index b451210d..04d4d42d 100644
--- a/extensions_built_in/concept_replacer/ConceptReplacer.py
+++ b/extensions_built_in/concept_replacer/ConceptReplacer.py
@@ -36,8 +36,6 @@ class ConceptReplacer(BaseSDTrainProcess):
 
         # textual inversion
         if self.embedding is not None:
-            # keep original embeddings as reference
-            self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
            # set text encoder to train. Not sure if this is necessary but diffusers example did it
             self.sd.text_encoder.train()
 
@@ -142,13 +140,7 @@ class ConceptReplacer(BaseSDTrainProcess):
 
         if self.embedding is not None:
             # Let's make sure we don't update any embedding weights besides the newly added token
-            index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
-            index_no_updates[
-            min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
-            with torch.no_grad():
-                self.sd.text_encoder.get_input_embeddings().weight[
-                    index_no_updates
-                ] = self.orig_embeds_params[index_no_updates]
+            self.embedding.restore_embeddings()
 
         loss_dict = OrderedDict(
             {'loss': loss.item()}
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 350f59bf..1ae27c9a 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -26,11 +26,9 @@ class SDTrainer(BaseSDTrainProcess):
         self.sd.vae.to(self.device_torch)
 
         # textual inversion
-        if self.embedding is not None:
-            # keep original embeddings as reference
-            self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
+        # if self.embedding is not None:
             # set text encoder to train. Not sure if this is necessary but diffusers example did it
-            self.sd.text_encoder.train()
+            # self.sd.text_encoder.train()
 
     def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -103,13 +101,7 @@ class SDTrainer(BaseSDTrainProcess):
 
         if self.embedding is not None:
             # Let's make sure we don't update any embedding weights besides the newly added token
-            index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
-            index_no_updates[
-            min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
-            with torch.no_grad():
-                self.sd.text_encoder.get_input_embeddings().weight[
-                    index_no_updates
-                ] = self.orig_embeds_params[index_no_updates]
+            self.embedding.restore_embeddings()
 
         loss_dict = OrderedDict(
             {'loss': loss.item()}
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e532ef31..d96f9b46 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -5,7 +5,7 @@ from collections import OrderedDict
 import os
 from typing import Union
 
-from lycoris.config import PRESET
+# from lycoris.config import PRESET
 from torch.utils.data import DataLoader
 
 from toolkit.data_loader import get_dataloader_from_datasets
@@ -126,7 +126,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         # to hold network if there is one
         self.network: Union[Network, None] = None
-        self.embedding = None
+        self.embedding: Union[Embedding, None] = None
 
     def sample(self, step=None, is_first=False):
         sample_folder = os.path.join(self.save_root, 'samples')
@@ -261,13 +261,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if self.network_config.normalize:
                 # apply the normalization
                 self.network.apply_stored_normalizer()
+
+            # if we are doing embedding training as well, add that
+            embedding_dict = self.embedding.state_dict() if self.embedding else None
             self.network.save_weights(
                 file_path,
                 dtype=get_torch_dtype(self.save_config.dtype),
-                metadata=save_meta
+                metadata=save_meta,
+                extra_state_dict=embedding_dict
             )
             self.network.multiplier = prev_multiplier
+        # if we have an embedding as well, pair it with the network
         elif self.embedding is not None:
+            # for combo, above will get it
            # set current step
             self.embedding.step = self.step_num
             # change filename to pt if that is set
@@ -330,16 +336,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
     def load_weights(self, path):
         if self.network is not None:
-            self.network.load_weights(path)
+            extra_weights = self.network.load_weights(path)
             meta = load_metadata_from_safetensors(path)
             # if 'training_info' in Orderdict keys
             if 'training_info' in meta and 'step' in meta['training_info']:
                 self.step_num = meta['training_info']['step']
                 self.start_step = self.step_num
                 print(f"Found step {self.step_num} in metadata, starting from there")
-
+            return extra_weights
         else:
             print("load_weights not implemented for non-network models")
+            return None
 
     def process_general_training_batch(self, batch):
         with torch.no_grad():
@@ -479,9 +486,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 NetworkClass = LycorisSpecialNetwork
                 is_lycoris = True
 
-            if is_lycoris:
-                preset = PRESET['full']
-                # NetworkClass.apply_preset(preset)
+            # if is_lycoris:
+            #     preset = PRESET['full']
+            #     NetworkClass.apply_preset(preset)
 
             self.network = NetworkClass(
                 text_encoder=text_encoder,
@@ -533,12 +540,25 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.network.is_normalizing = self.network_config.normalize
 
             latest_save_path = self.get_latest_save_path()
+            extra_weights = None
             if latest_save_path is not None:
                 self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
                 self.print(f"Loading from {latest_save_path}")
-                self.load_weights(latest_save_path)
+                extra_weights = self.load_weights(latest_save_path)
 
             self.network.multiplier = 1.0
 
+            if self.embed_config is not None:
+                # we are doing embedding training as well
+                self.embedding = Embedding(
+                    sd=self.sd,
+                    embed_config=self.embed_config,
+                    state_dict=extra_weights
+                )
+                params.append({
+                    'params': self.embedding.get_trainable_params(),
+                    'lr': self.train_config.embedding_lr
+                })
+                flush()
         elif self.embed_config is not None:
             self.embedding = Embedding(
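For reference, the masking trick that the two trainer hunks above now delegate to `Embedding.restore_embeddings()` (defined in the `toolkit/embedding.py` section further down) can be sketched in isolation roughly as follows; the helper name and the toy sizes are illustrative, not part of the patch:

```python
import torch

def restore_frozen_rows(embedding_layer: torch.nn.Embedding,
                        orig_weights: torch.Tensor,
                        placeholder_token_ids: list):
    # Copy the original weights back into every row except the newly added
    # placeholder tokens, so only those rows are effectively trained.
    index_no_updates = torch.ones((orig_weights.shape[0],), dtype=torch.bool)
    index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
    with torch.no_grad():
        embedding_layer.weight[index_no_updates] = orig_weights[index_no_updates]

# toy usage: a 10-token vocabulary where tokens 8 and 9 were just added
emb = torch.nn.Embedding(10, 4)
orig = emb.weight.data.clone()
# ... an optimizer step would modify emb.weight here ...
restore_frozen_rows(emb, orig, placeholder_token_ids=[8, 9])
```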
diff --git a/requirements.txt b/requirements.txt
index 2895366b..8c848127 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,4 +16,7 @@ toml
 albumentations
 pydantic
 omegaconf
-k-diffusion
\ No newline at end of file
+k-diffusion
+open_clip_torch
+timm
+prodigyopt
\ No newline at end of file
diff --git a/toolkit/buckets.py b/toolkit/buckets.py
index e7b6b1af..750b7d2d 100644
--- a/toolkit/buckets.py
+++ b/toolkit/buckets.py
@@ -1,6 +1,10 @@
-from typing import Type, List, Union
+from typing import Type, List, Union, TypedDict
+
+
+class BucketResolution(TypedDict):
+    width: int
+    height: int
 
-BucketResolution = Type[{"width": int, "height": int}]
 
 # resolutions SDXL was trained on with a 1024x1024 base resolution
 resolutions_1024: List[BucketResolution] = [
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index f3ea686b..055a4f7d 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -72,6 +72,7 @@ class TrainConfig:
         self.lr = kwargs.get('lr', 1e-6)
         self.unet_lr = kwargs.get('unet_lr', self.lr)
         self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
+        self.embedding_lr = kwargs.get('embedding_lr', self.lr)
         self.optimizer = kwargs.get('optimizer', 'adamw')
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 9aaa3a7c..7ed869b6 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -3,6 +3,7 @@ import os
 import random
 from typing import TYPE_CHECKING, List, Dict, Union
 
+from toolkit.buckets import get_bucket_for_image_size
 from toolkit.prompt_utils import inject_trigger_into_prompt
 from torchvision import transforms
 from PIL import Image
@@ -102,54 +103,21 @@ class BucketsMixin:
             width = file_item.crop_width
             height = file_item.crop_height
 
-            # determine new resolution to have the same number of pixels
-            current_pixels = width * height
-            if current_pixels == total_pixels:
-                file_item.scale_to_width = width
-                file_item.scale_to_height = height
-                file_item.crop_width = width
-                file_item.crop_height = height
-                new_width = width
-                new_height = height
+            bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution)
+
+            # set the scaling height and width to match the smallest size, and keep aspect ratio
+            if width > height:
+                file_item.scale_height = bucket_resolution["height"]
+                file_item.scale_width = int(width * (bucket_resolution["height"] / height))
             else:
+                file_item.scale_width = bucket_resolution["width"]
+                file_item.scale_height = int(height * (bucket_resolution["width"] / width))
 
-                aspect_ratio = width / height
-                new_height = int(math.sqrt(total_pixels / aspect_ratio))
-                new_width = int(aspect_ratio * new_height)
+            file_item.crop_height = bucket_resolution["height"]
+            file_item.crop_width = bucket_resolution["width"]
 
-                # increase smallest one to be divisible by bucket_tolerance and increase the other to match
-                if new_width < new_height:
-                    # increase width
-                    if new_width % bucket_tolerance != 0:
-                        crop_amount = new_width % bucket_tolerance
-                        new_width = new_width + (bucket_tolerance - crop_amount)
-                        new_height = int(new_width / aspect_ratio)
-                else:
-                    # increase height
-                    if new_height % bucket_tolerance != 0:
-                        crop_amount = new_height % bucket_tolerance
-                        new_height = new_height + (bucket_tolerance - crop_amount)
-                        new_width = int(aspect_ratio * new_height)
-
-                # Ensure that the total number of pixels remains the same.
-                # assert new_width * new_height == total_pixels
-
-                file_item.scale_to_width = new_width
-                file_item.scale_to_height = new_height
-                file_item.crop_width = new_width
-                file_item.crop_height = new_height
-            # make sure it is divisible by bucket_tolerance, decrease if not
-            if new_width % bucket_tolerance != 0:
-                crop_amount = new_width % bucket_tolerance
-                file_item.crop_width = new_width - crop_amount
-            else:
-                file_item.crop_width = new_width
-
-            if new_height % bucket_tolerance != 0:
-                crop_amount = new_height % bucket_tolerance
-                file_item.crop_height = new_height - crop_amount
-            else:
-                file_item.crop_height = new_height
+            new_width = bucket_resolution["width"]
+            new_height = bucket_resolution["height"]
 
             # check if bucket exists, if not, create it
             bucket_key = f'{new_width}x{new_height}'
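The new bucketing path delegates bucket selection to `get_bucket_for_image_size()` (its implementation is not part of this diff) and then scales the image to the bucket before cropping. A rough standalone sketch of that scale-then-crop arithmetic, using a hypothetical bucket value:

```python
def scale_and_crop_for_bucket(width: int, height: int, bucket: dict) -> dict:
    # Match the bucket's height for landscape images (or width for portrait),
    # keep the aspect ratio, then crop to the exact bucket resolution.
    if width > height:
        scale_height = bucket["height"]
        scale_width = int(width * (bucket["height"] / height))
    else:
        scale_width = bucket["width"]
        scale_height = int(height * (bucket["width"] / width))
    return {
        "scale_width": scale_width,
        "scale_height": scale_height,
        "crop_width": bucket["width"],
        "crop_height": bucket["height"],
    }

# e.g. a 1600x900 image mapped to an illustrative 1216x704 bucket
print(scale_and_crop_for_bucket(1600, 900, {"width": 1216, "height": 704}))
```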
diff --git a/toolkit/embedding.py b/toolkit/embedding.py
index 3eb4483a..3bc0a68e 100644
--- a/toolkit/embedding.py
+++ b/toolkit/embedding.py
@@ -21,7 +21,8 @@ class Embedding:
     def __init__(
             self,
             sd: 'StableDiffusion',
-            embed_config: 'EmbeddingConfig'
+            embed_config: 'EmbeddingConfig',
+            state_dict: OrderedDict = None,
     ):
         self.name = embed_config.trigger
         self.sd = sd
@@ -38,74 +39,112 @@ class Embedding:
             additional_tokens.append(f"{self.embed_config.trigger}_{i}")
         placeholder_tokens += additional_tokens
 
-        num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
-        if num_added_tokens != self.embed_config.tokens:
-            raise ValueError(
-                f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
+        # handle dual tokenizer
+        self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
+        self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
+            self.sd.text_encoder]
 
-        # Convert the initializer_token, placeholder_token to ids
-        init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
-        # if length of token ids is more than number of orm embedding tokens fill with *
-        if len(init_token_ids) > self.embed_config.tokens:
-            init_token_ids = init_token_ids[:self.embed_config.tokens]
-        elif len(init_token_ids) < self.embed_config.tokens:
-            pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
-            init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
+        self.placeholder_token_ids = []
+        self.embedding_tokens = []
 
-        self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)
+        for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
+            num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+            if num_added_tokens != self.embed_config.tokens:
+                raise ValueError(
+                    f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
+                    " `placeholder_token` that is not already in the tokenizer."
+                )
 
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        # todo SDXL has 2 text encoders, need to do both for all of this
-        self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))
+            # Convert the initializer_token, placeholder_token to ids
+            init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
+            # if length of token ids is more than number of orm embedding tokens fill with *
+            if len(init_token_ids) > self.embed_config.tokens:
+                init_token_ids = init_token_ids[:self.embed_config.tokens]
+            elif len(init_token_ids) < self.embed_config.tokens:
+                pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
+                init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
 
-        # Initialise the newly added placeholder token with the embeddings of the initializer token
-        token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
-        with torch.no_grad():
-            for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
-                token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+            placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
+            self.placeholder_token_ids.append(placeholder_token_ids)
 
-        # replace "[name] with this. on training. This is automatically generated in pipeline on inference
-        self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
+            # Resize the token embeddings as we are adding new special tokens to the tokenizer
+            text_encoder.resize_token_embeddings(len(tokenizer))
 
-    # returns the string to have in the prompt to trigger the embedding
-    def get_embedding_string(self):
-        return self.embedding_tokens
+            # Initialise the newly added placeholder token with the embeddings of the initializer token
+            token_embeds = text_encoder.get_input_embeddings().weight.data
+            with torch.no_grad():
+                for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
+                    token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+
+            # replace "[name]" with this on training. This is automatically generated in the pipeline on inference
+            self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
+
+        # backup text encoder embeddings
+        self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
+
+    def restore_embeddings(self):
+        # Let's make sure we don't update any embedding weights besides the newly added token
+        for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
+                                                                               self.tokenizer_list,
+                                                                               self.orig_embeds_params,
+                                                                               self.placeholder_token_ids):
+            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+            index_no_updates[
+            min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
+            with torch.no_grad():
+                text_encoder.get_input_embeddings().weight[
+                    index_no_updates
+                ] = orig_embeds[index_no_updates]
 
     def get_trainable_params(self):
-        # todo only get this one as we could have more than one
-        return self.sd.text_encoder.get_input_embeddings().parameters()
+        params = []
+        for text_encoder in self.text_encoder_list:
+            params += text_encoder.get_input_embeddings().parameters()
+        return params
 
-    # make setter and getter for vec
-    @property
-    def vec(self):
+    def _get_vec(self, text_encoder_idx=0):
         # should we get params instead
         # create vector from token embeds
-        token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
+        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
         # stack the tokens along batch axis adding that axis
         new_vector = torch.stack(
-            [token_embeds[token_id] for token_id in self.placeholder_token_ids],
+            [token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
             dim=0
         )
         return new_vector
 
-    @vec.setter
-    def vec(self, new_vector):
+    def _set_vec(self, new_vector, text_encoder_idx=0):
         # shape is (1, 768) for SD 1.5 for 1 token
-        token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
+        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
         for i in range(new_vector.shape[0]):
             # apply the weights to the placeholder tokens while preserving gradient
-            token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
-            x = 1
+            token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
+
+    # make setter and getter for vec
+    @property
+    def vec(self):
+        return self._get_vec(0)
+
+    @vec.setter
+    def vec(self, new_vector):
+        self._set_vec(new_vector, 0)
+
+    @property
+    def vec2(self):
+        return self._get_vec(1)
+
+    @vec2.setter
+    def vec2(self, new_vector):
+        self._set_vec(new_vector, 1)
 
     # diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
     # however, on training we don't use that pipeline, so we have to do it ourselves
     def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
         output_prompt = prompt
-        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
+        embedding_tokens = self.embedding_tokens[0]  # should be the same
+        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", embedding_tokens]
 
-        replace_with = self.embedding_tokens if expand_token else self.trigger
+        replace_with = embedding_tokens if expand_token else self.trigger
         if to_replace_list is None:
             to_replace_list = default_replacements
         else:
@@ -132,6 +171,17 @@ class Embedding:
 
         return output_prompt
 
+    def state_dict(self):
+        if self.sd.is_xl:
+            state_dict = OrderedDict()
+            state_dict['clip_l'] = self.vec
+            state_dict['clip_g'] = self.vec2
+        else:
+            state_dict = OrderedDict()
+            state_dict['emb_params'] = self.vec
+
+        return state_dict
+
     def save(self, filename):
         # todo check to see how to get the vector out of the embedding
@@ -145,13 +195,14 @@ class Embedding:
             "sd_checkpoint_name": None,
             "notes": None,
         }
+        # TODO we do not currently support this. Check how auto is doing it. Only safetensors supported for sdxl
         if filename.endswith('.pt'):
             torch.save(embedding_data, filename)
         elif filename.endswith('.bin'):
             torch.save(embedding_data, filename)
         elif filename.endswith('.safetensors'):
             # save the embedding as a safetensors file
-            state_dict = {"emb_params": self.vec}
+            state_dict = self.state_dict()
             # add all embedding data (except string_to_param), to metadata
             metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
             metadata["string_to_param"] = {"*": "emb_params"}
@@ -163,6 +214,7 @@ class Embedding:
         path = os.path.realpath(file_path)
         filename = os.path.basename(path)
         name, ext = os.path.splitext(filename)
+        tensors = {}
         ext = ext.upper()
         if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
             _, second_ext = os.path.splitext(name)
@@ -170,10 +222,12 @@ class Embedding:
             return
 
         if ext in ['.BIN', '.PT']:
+            # todo check this
+            if self.sd.is_xl:
+                raise Exception("XL not supported yet for bin, pt")
             data = torch.load(path, map_location="cpu")
         elif ext in ['.SAFETENSORS']:
             # rebuild the embedding from the safetensors file if it has it
-            tensors = {}
             with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
                 metadata = f.metadata()
                 for k in f.keys():
@@ -217,4 +271,8 @@ class Embedding:
         if 'step' in data:
             self.step = int(data['step'])
 
-        self.vec = emb.detach().to(device, dtype=torch.float32)
+        if self.sd.is_xl:
+            self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
+            self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
+        else:
+            self.vec = emb.detach().to(device, dtype=torch.float32)
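The new `Embedding.state_dict()` stores one tensor per SDXL text encoder under `clip_l` and `clip_g`, or a single `emb_params` tensor otherwise. A hedged sketch of how such a file would round-trip through safetensors (the shapes assume 2 placeholder tokens and the usual CLIP-L/CLIP-G hidden widths; the file name is made up):

```python
from collections import OrderedDict
import torch
import safetensors.torch

# Illustrative only: 2 placeholder tokens, 768-dim CLIP-L and 1280-dim CLIP-G rows.
state_dict = OrderedDict()
state_dict["clip_l"] = torch.randn(2, 768)
state_dict["clip_g"] = torch.randn(2, 1280)

safetensors.torch.save_file(state_dict, "my_embedding.safetensors")

loaded = safetensors.torch.load_file("my_embedding.safetensors")
assert loaded["clip_l"].shape == (2, 768)
assert loaded["clip_g"].shape == (2, 1280)
```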
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index beb6596a..75d74c6d 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -228,9 +228,15 @@ class ToolkitNetworkMixin:
 
         return keymap
 
-    def save_weights(self: Network, file, dtype=torch.float16, metadata=None):
+    def save_weights(
+            self: Network,
+            file, dtype=torch.float16,
+            metadata=None,
+            extra_state_dict: Optional[OrderedDict] = None
+    ):
         keymap = self.get_keymap()
+
         save_keymap = {}
         if keymap is not None:
             for ldm_key, diffusers_key in keymap.items():
@@ -249,6 +255,13 @@ class ToolkitNetworkMixin:
             save_key = save_keymap[key] if key in save_keymap else key
             save_dict[save_key] = v
 
+        if extra_state_dict is not None:
+            # add extra items to state dict
+            for key in list(extra_state_dict.keys()):
+                v = extra_state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                save_dict[key] = v
+
         if metadata is None:
             metadata = OrderedDict()
         metadata = add_model_hash_to_meta(state_dict, metadata)
@@ -275,8 +288,21 @@ class ToolkitNetworkMixin:
                 load_key = keymap[key] if key in keymap else key
                 load_sd[load_key] = value
 
+        # extract extra items from state dict
+        current_state_dict = self.state_dict()
+        extra_dict = OrderedDict()
+        to_delete = []
+        for key in list(load_sd.keys()):
+            if key not in current_state_dict:
+                extra_dict[key] = load_sd[key]
+                to_delete.append(key)
+        for key in to_delete:
+            del load_sd[key]
+
         info = self.load_state_dict(load_sd, False)
-        return info
+        if len(extra_dict.keys()) == 0:
+            extra_dict = None
+        return extra_dict
 
     @property
     def multiplier(self) -> Union[float, List[float]]:
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index c591c5de..34ccb396 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -442,21 +442,25 @@ class StableDiffusion:
         return noise
 
     def get_time_ids_from_latents(self, latents: torch.Tensor):
-        bs, ch, h, w = list(latents.shape)
-
-        height = h * VAE_SCALE_FACTOR
-        width = w * VAE_SCALE_FACTOR
-
-        dtype = latents.dtype
         if self.is_xl:
-            prompt_ids = train_tools.get_add_time_ids(
-                height,
-                width,
-                dynamic_crops=False,  # look into this
-                dtype=dtype,
-            ).to(self.device_torch, dtype=dtype)
-            return prompt_ids
+            bs, ch, h, w = list(latents.shape)
+
+            height = h * VAE_SCALE_FACTOR
+            width = w * VAE_SCALE_FACTOR
+
+            dtype = latents.dtype
+            # just do it without any cropping nonsense
+            target_size = (height, width)
+            original_size = (height, width)
+            crops_coords_top_left = (0, 0)
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_time_ids = torch.tensor([add_time_ids])
+            add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
+
+            batch_time_ids = torch.cat(
+                [add_time_ids for _ in range(bs)]
+            )
+            return batch_time_ids
         else:
             return None
@@ -682,7 +686,7 @@ class StableDiffusion:
         if self.vae.device == 'cpu':
             self.vae.to(self.device)
         latents = latents.to(device, dtype=dtype)
-        latents = latents / 0.18215
+        latents = latents / self.vae.config['scaling_factor']
         images = self.vae.decode(latents).sample
         images = images.to(device, dtype=dtype)
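The `save_weights`/`load_weights` changes above let embedding vectors ride along inside the network file: any keys the network does not own are split out on load and returned to the caller. A minimal sketch of that split, with plain strings standing in for the real tensors:

```python
from collections import OrderedDict

def split_extra_keys(load_sd: dict, network_keys) -> tuple:
    # Pull out any keys the network itself does not own (e.g. bundled
    # embedding vectors) before load_state_dict is called.
    extra = OrderedDict()
    for key in list(load_sd.keys()):
        if key not in network_keys:
            extra[key] = load_sd.pop(key)
    return load_sd, (extra if len(extra) else None)

# illustrative usage: network weights plus a bundled 'emb_params' entry
weights = {"lora_unet_down_block.weight": "...", "emb_params": "..."}
network_sd, extra = split_extra_keys(dict(weights), {"lora_unet_down_block.weight"})
print(extra)  # -> only the bundled 'emb_params' entry remains here
```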