mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Switched to new bucket system that matched sdxl trained buckets. Fixed requirements. Updated embeddings to work with sdxl. Added method to train lora with an embedding at the trigger. Still testing but works amazingly well from what I can see
This commit is contained in:
@@ -36,8 +36,6 @@ class ConceptReplacer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
# textual inversion
|
# textual inversion
|
||||||
if self.embedding is not None:
|
if self.embedding is not None:
|
||||||
# keep original embeddings as reference
|
|
||||||
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
|
|
||||||
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
||||||
self.sd.text_encoder.train()
|
self.sd.text_encoder.train()
|
||||||
|
|
||||||
@@ -142,13 +140,7 @@ class ConceptReplacer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if self.embedding is not None:
|
if self.embedding is not None:
|
||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
|
self.embedding.restore_embeddings()
|
||||||
index_no_updates[
|
|
||||||
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
|
|
||||||
with torch.no_grad():
|
|
||||||
self.sd.text_encoder.get_input_embeddings().weight[
|
|
||||||
index_no_updates
|
|
||||||
] = self.orig_embeds_params[index_no_updates]
|
|
||||||
|
|
||||||
loss_dict = OrderedDict(
|
loss_dict = OrderedDict(
|
||||||
{'loss': loss.item()}
|
{'loss': loss.item()}
|
||||||
|
|||||||
@@ -26,11 +26,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self.sd.vae.to(self.device_torch)
|
self.sd.vae.to(self.device_torch)
|
||||||
|
|
||||||
# textual inversion
|
# textual inversion
|
||||||
if self.embedding is not None:
|
# if self.embedding is not None:
|
||||||
# keep original embeddings as reference
|
|
||||||
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
|
|
||||||
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
||||||
self.sd.text_encoder.train()
|
# self.sd.text_encoder.train()
|
||||||
|
|
||||||
def hook_train_loop(self, batch):
|
def hook_train_loop(self, batch):
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
@@ -103,13 +101,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if self.embedding is not None:
|
if self.embedding is not None:
|
||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
|
self.embedding.restore_embeddings()
|
||||||
index_no_updates[
|
|
||||||
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
|
|
||||||
with torch.no_grad():
|
|
||||||
self.sd.text_encoder.get_input_embeddings().weight[
|
|
||||||
index_no_updates
|
|
||||||
] = self.orig_embeds_params[index_no_updates]
|
|
||||||
|
|
||||||
loss_dict = OrderedDict(
|
loss_dict = OrderedDict(
|
||||||
{'loss': loss.item()}
|
{'loss': loss.item()}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from collections import OrderedDict
|
|||||||
import os
|
import os
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from lycoris.config import PRESET
|
# from lycoris.config import PRESET
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from toolkit.data_loader import get_dataloader_from_datasets
|
from toolkit.data_loader import get_dataloader_from_datasets
|
||||||
@@ -126,7 +126,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
# to hold network if there is one
|
# to hold network if there is one
|
||||||
self.network: Union[Network, None] = None
|
self.network: Union[Network, None] = None
|
||||||
self.embedding = None
|
self.embedding: Union[Embedding, None] = None
|
||||||
|
|
||||||
def sample(self, step=None, is_first=False):
|
def sample(self, step=None, is_first=False):
|
||||||
sample_folder = os.path.join(self.save_root, 'samples')
|
sample_folder = os.path.join(self.save_root, 'samples')
|
||||||
@@ -261,13 +261,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if self.network_config.normalize:
|
if self.network_config.normalize:
|
||||||
# apply the normalization
|
# apply the normalization
|
||||||
self.network.apply_stored_normalizer()
|
self.network.apply_stored_normalizer()
|
||||||
|
|
||||||
|
# if we are doing embedding training as well, add that
|
||||||
|
embedding_dict = self.embedding.state_dict() if self.embedding else None
|
||||||
self.network.save_weights(
|
self.network.save_weights(
|
||||||
file_path,
|
file_path,
|
||||||
dtype=get_torch_dtype(self.save_config.dtype),
|
dtype=get_torch_dtype(self.save_config.dtype),
|
||||||
metadata=save_meta
|
metadata=save_meta,
|
||||||
|
extra_state_dict=embedding_dict
|
||||||
)
|
)
|
||||||
self.network.multiplier = prev_multiplier
|
self.network.multiplier = prev_multiplier
|
||||||
|
# if we have an embedding as well, pair it with the network
|
||||||
elif self.embedding is not None:
|
elif self.embedding is not None:
|
||||||
|
# for combo, above will get it
|
||||||
# set current step
|
# set current step
|
||||||
self.embedding.step = self.step_num
|
self.embedding.step = self.step_num
|
||||||
# change filename to pt if that is set
|
# change filename to pt if that is set
|
||||||
@@ -330,16 +336,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
def load_weights(self, path):
|
def load_weights(self, path):
|
||||||
if self.network is not None:
|
if self.network is not None:
|
||||||
self.network.load_weights(path)
|
extra_weights = self.network.load_weights(path)
|
||||||
meta = load_metadata_from_safetensors(path)
|
meta = load_metadata_from_safetensors(path)
|
||||||
# if 'training_info' in Orderdict keys
|
# if 'training_info' in Orderdict keys
|
||||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||||
self.step_num = meta['training_info']['step']
|
self.step_num = meta['training_info']['step']
|
||||||
self.start_step = self.step_num
|
self.start_step = self.step_num
|
||||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||||
|
return extra_weights
|
||||||
else:
|
else:
|
||||||
print("load_weights not implemented for non-network models")
|
print("load_weights not implemented for non-network models")
|
||||||
|
return None
|
||||||
|
|
||||||
def process_general_training_batch(self, batch):
|
def process_general_training_batch(self, batch):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -479,9 +486,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
NetworkClass = LycorisSpecialNetwork
|
NetworkClass = LycorisSpecialNetwork
|
||||||
is_lycoris = True
|
is_lycoris = True
|
||||||
|
|
||||||
if is_lycoris:
|
# if is_lycoris:
|
||||||
preset = PRESET['full']
|
# preset = PRESET['full']
|
||||||
# NetworkClass.apply_preset(preset)
|
# NetworkClass.apply_preset(preset)
|
||||||
|
|
||||||
self.network = NetworkClass(
|
self.network = NetworkClass(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@@ -533,12 +540,25 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.network.is_normalizing = self.network_config.normalize
|
self.network.is_normalizing = self.network_config.normalize
|
||||||
|
|
||||||
latest_save_path = self.get_latest_save_path()
|
latest_save_path = self.get_latest_save_path()
|
||||||
|
extra_weights = None
|
||||||
if latest_save_path is not None:
|
if latest_save_path is not None:
|
||||||
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||||
self.print(f"Loading from {latest_save_path}")
|
self.print(f"Loading from {latest_save_path}")
|
||||||
self.load_weights(latest_save_path)
|
extra_weights = self.load_weights(latest_save_path)
|
||||||
self.network.multiplier = 1.0
|
self.network.multiplier = 1.0
|
||||||
|
|
||||||
|
if self.embed_config is not None:
|
||||||
|
# we are doing embedding training as well
|
||||||
|
self.embedding = Embedding(
|
||||||
|
sd=self.sd,
|
||||||
|
embed_config=self.embed_config,
|
||||||
|
state_dict=extra_weights
|
||||||
|
)
|
||||||
|
params.append({
|
||||||
|
'params': self.embedding.get_trainable_params(),
|
||||||
|
'lr': self.train_config.embedding_lr
|
||||||
|
})
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
elif self.embed_config is not None:
|
elif self.embed_config is not None:
|
||||||
self.embedding = Embedding(
|
self.embedding = Embedding(
|
||||||
|
|||||||
@@ -16,4 +16,7 @@ toml
|
|||||||
albumentations
|
albumentations
|
||||||
pydantic
|
pydantic
|
||||||
omegaconf
|
omegaconf
|
||||||
k-diffusion
|
k-diffusion
|
||||||
|
open_clip_torch
|
||||||
|
timm
|
||||||
|
prodigyopt
|
||||||
@@ -1,6 +1,10 @@
|
|||||||
from typing import Type, List, Union
|
from typing import Type, List, Union, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class BucketResolution(TypedDict):
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
|
||||||
BucketResolution = Type[{"width": int, "height": int}]
|
|
||||||
|
|
||||||
# resolutions SDXL was trained on with a 1024x1024 base resolution
|
# resolutions SDXL was trained on with a 1024x1024 base resolution
|
||||||
resolutions_1024: List[BucketResolution] = [
|
resolutions_1024: List[BucketResolution] = [
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class TrainConfig:
|
|||||||
self.lr = kwargs.get('lr', 1e-6)
|
self.lr = kwargs.get('lr', 1e-6)
|
||||||
self.unet_lr = kwargs.get('unet_lr', self.lr)
|
self.unet_lr = kwargs.get('unet_lr', self.lr)
|
||||||
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
|
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
|
||||||
|
self.embedding_lr = kwargs.get('embedding_lr', self.lr)
|
||||||
self.optimizer = kwargs.get('optimizer', 'adamw')
|
self.optimizer = kwargs.get('optimizer', 'adamw')
|
||||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||||
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
|
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
from typing import TYPE_CHECKING, List, Dict, Union
|
from typing import TYPE_CHECKING, List, Dict, Union
|
||||||
|
|
||||||
|
from toolkit.buckets import get_bucket_for_image_size
|
||||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -102,54 +103,21 @@ class BucketsMixin:
|
|||||||
width = file_item.crop_width
|
width = file_item.crop_width
|
||||||
height = file_item.crop_height
|
height = file_item.crop_height
|
||||||
|
|
||||||
# determine new resolution to have the same number of pixels
|
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution)
|
||||||
current_pixels = width * height
|
|
||||||
if current_pixels == total_pixels:
|
# set the scaling height and with to match smallest size, and keep aspect ratio
|
||||||
file_item.scale_to_width = width
|
if width > height:
|
||||||
file_item.scale_to_height = height
|
file_item.scale_height = bucket_resolution["height"]
|
||||||
file_item.crop_width = width
|
file_item.scale_width = int(width * (bucket_resolution["height"] / height))
|
||||||
file_item.crop_height = height
|
|
||||||
new_width = width
|
|
||||||
new_height = height
|
|
||||||
else:
|
else:
|
||||||
|
file_item.scale_width = bucket_resolution["width"]
|
||||||
|
file_item.scale_height = int(height * (bucket_resolution["width"] / width))
|
||||||
|
|
||||||
aspect_ratio = width / height
|
file_item.crop_height = bucket_resolution["height"]
|
||||||
new_height = int(math.sqrt(total_pixels / aspect_ratio))
|
file_item.crop_width = bucket_resolution["width"]
|
||||||
new_width = int(aspect_ratio * new_height)
|
|
||||||
|
|
||||||
# increase smallest one to be divisible by bucket_tolerance and increase the other to match
|
new_width = bucket_resolution["width"]
|
||||||
if new_width < new_height:
|
new_height = bucket_resolution["height"]
|
||||||
# increase width
|
|
||||||
if new_width % bucket_tolerance != 0:
|
|
||||||
crop_amount = new_width % bucket_tolerance
|
|
||||||
new_width = new_width + (bucket_tolerance - crop_amount)
|
|
||||||
new_height = int(new_width / aspect_ratio)
|
|
||||||
else:
|
|
||||||
# increase height
|
|
||||||
if new_height % bucket_tolerance != 0:
|
|
||||||
crop_amount = new_height % bucket_tolerance
|
|
||||||
new_height = new_height + (bucket_tolerance - crop_amount)
|
|
||||||
new_width = int(aspect_ratio * new_height)
|
|
||||||
|
|
||||||
# Ensure that the total number of pixels remains the same.
|
|
||||||
# assert new_width * new_height == total_pixels
|
|
||||||
|
|
||||||
file_item.scale_to_width = new_width
|
|
||||||
file_item.scale_to_height = new_height
|
|
||||||
file_item.crop_width = new_width
|
|
||||||
file_item.crop_height = new_height
|
|
||||||
# make sure it is divisible by bucket_tolerance, decrease if not
|
|
||||||
if new_width % bucket_tolerance != 0:
|
|
||||||
crop_amount = new_width % bucket_tolerance
|
|
||||||
file_item.crop_width = new_width - crop_amount
|
|
||||||
else:
|
|
||||||
file_item.crop_width = new_width
|
|
||||||
|
|
||||||
if new_height % bucket_tolerance != 0:
|
|
||||||
crop_amount = new_height % bucket_tolerance
|
|
||||||
file_item.crop_height = new_height - crop_amount
|
|
||||||
else:
|
|
||||||
file_item.crop_height = new_height
|
|
||||||
|
|
||||||
# check if bucket exists, if not, create it
|
# check if bucket exists, if not, create it
|
||||||
bucket_key = f'{new_width}x{new_height}'
|
bucket_key = f'{new_width}x{new_height}'
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ class Embedding:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sd: 'StableDiffusion',
|
sd: 'StableDiffusion',
|
||||||
embed_config: 'EmbeddingConfig'
|
embed_config: 'EmbeddingConfig',
|
||||||
|
state_dict: OrderedDict = None,
|
||||||
):
|
):
|
||||||
self.name = embed_config.trigger
|
self.name = embed_config.trigger
|
||||||
self.sd = sd
|
self.sd = sd
|
||||||
@@ -38,74 +39,112 @@ class Embedding:
|
|||||||
additional_tokens.append(f"{self.embed_config.trigger}_{i}")
|
additional_tokens.append(f"{self.embed_config.trigger}_{i}")
|
||||||
placeholder_tokens += additional_tokens
|
placeholder_tokens += additional_tokens
|
||||||
|
|
||||||
num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
|
# handle dual tokenizer
|
||||||
if num_added_tokens != self.embed_config.tokens:
|
self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
|
||||||
raise ValueError(
|
self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
|
||||||
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
|
self.sd.text_encoder]
|
||||||
" `placeholder_token` that is not already in the tokenizer."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert the initializer_token, placeholder_token to ids
|
self.placeholder_token_ids = []
|
||||||
init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
|
self.embedding_tokens = []
|
||||||
# if length of token ids is more than number of orm embedding tokens fill with *
|
|
||||||
if len(init_token_ids) > self.embed_config.tokens:
|
|
||||||
init_token_ids = init_token_ids[:self.embed_config.tokens]
|
|
||||||
elif len(init_token_ids) < self.embed_config.tokens:
|
|
||||||
pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
|
|
||||||
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
|
|
||||||
|
|
||||||
self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)
|
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
|
||||||
|
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
|
||||||
|
if num_added_tokens != self.embed_config.tokens:
|
||||||
|
raise ValueError(
|
||||||
|
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
|
||||||
|
" `placeholder_token` that is not already in the tokenizer."
|
||||||
|
)
|
||||||
|
|
||||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
# Convert the initializer_token, placeholder_token to ids
|
||||||
# todo SDXL has 2 text encoders, need to do both for all of this
|
init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
|
||||||
self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))
|
# if length of token ids is more than number of orm embedding tokens fill with *
|
||||||
|
if len(init_token_ids) > self.embed_config.tokens:
|
||||||
|
init_token_ids = init_token_ids[:self.embed_config.tokens]
|
||||||
|
elif len(init_token_ids) < self.embed_config.tokens:
|
||||||
|
pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
|
||||||
|
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
|
||||||
|
|
||||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
|
||||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
self.placeholder_token_ids.append(placeholder_token_ids)
|
||||||
with torch.no_grad():
|
|
||||||
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
|
|
||||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
|
||||||
|
|
||||||
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
|
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||||
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
# returns the string to have in the prompt to trigger the embedding
|
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||||
def get_embedding_string(self):
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
return self.embedding_tokens
|
with torch.no_grad():
|
||||||
|
for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
|
||||||
|
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||||
|
|
||||||
|
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
|
||||||
|
self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
|
||||||
|
|
||||||
|
# backup text encoder embeddings
|
||||||
|
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
|
||||||
|
|
||||||
|
def restore_embeddings(self):
|
||||||
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
|
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
|
||||||
|
self.tokenizer_list,
|
||||||
|
self.orig_embeds_params,
|
||||||
|
self.placeholder_token_ids):
|
||||||
|
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||||
|
index_no_updates[
|
||||||
|
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
|
||||||
|
with torch.no_grad():
|
||||||
|
text_encoder.get_input_embeddings().weight[
|
||||||
|
index_no_updates
|
||||||
|
] = orig_embeds[index_no_updates]
|
||||||
|
|
||||||
def get_trainable_params(self):
|
def get_trainable_params(self):
|
||||||
# todo only get this one as we could have more than one
|
params = []
|
||||||
return self.sd.text_encoder.get_input_embeddings().parameters()
|
for text_encoder in self.text_encoder_list:
|
||||||
|
params += text_encoder.get_input_embeddings().parameters()
|
||||||
|
return params
|
||||||
|
|
||||||
# make setter and getter for vec
|
def _get_vec(self, text_encoder_idx=0):
|
||||||
@property
|
|
||||||
def vec(self):
|
|
||||||
# should we get params instead
|
# should we get params instead
|
||||||
# create vector from token embeds
|
# create vector from token embeds
|
||||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
|
||||||
# stack the tokens along batch axis adding that axis
|
# stack the tokens along batch axis adding that axis
|
||||||
new_vector = torch.stack(
|
new_vector = torch.stack(
|
||||||
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
|
[token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
|
||||||
dim=0
|
dim=0
|
||||||
)
|
)
|
||||||
return new_vector
|
return new_vector
|
||||||
|
|
||||||
@vec.setter
|
def _set_vec(self, new_vector, text_encoder_idx=0):
|
||||||
def vec(self, new_vector):
|
|
||||||
# shape is (1, 768) for SD 1.5 for 1 token
|
# shape is (1, 768) for SD 1.5 for 1 token
|
||||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
token_embeds = self.text_encoder_list[0].get_input_embeddings().weight.data
|
||||||
for i in range(new_vector.shape[0]):
|
for i in range(new_vector.shape[0]):
|
||||||
# apply the weights to the placeholder tokens while preserving gradient
|
# apply the weights to the placeholder tokens while preserving gradient
|
||||||
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
|
token_embeds[self.placeholder_token_ids[0][i]] = new_vector[i].clone()
|
||||||
x = 1
|
|
||||||
|
# make setter and getter for vec
|
||||||
|
@property
|
||||||
|
def vec(self):
|
||||||
|
return self._get_vec(0)
|
||||||
|
|
||||||
|
@vec.setter
|
||||||
|
def vec(self, new_vector):
|
||||||
|
self._set_vec(new_vector, 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vec2(self):
|
||||||
|
return self._get_vec(1)
|
||||||
|
|
||||||
|
@vec2.setter
|
||||||
|
def vec2(self, new_vector):
|
||||||
|
self._set_vec(new_vector, 1)
|
||||||
|
|
||||||
# diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
|
# diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
|
||||||
# however, on training we don't use that pipeline, so we have to do it ourselves
|
# however, on training we don't use that pipeline, so we have to do it ourselves
|
||||||
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
|
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
|
||||||
output_prompt = prompt
|
output_prompt = prompt
|
||||||
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
|
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
|
||||||
|
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", embedding_tokens]
|
||||||
|
|
||||||
replace_with = self.embedding_tokens if expand_token else self.trigger
|
replace_with = embedding_tokens if expand_token else self.trigger
|
||||||
if to_replace_list is None:
|
if to_replace_list is None:
|
||||||
to_replace_list = default_replacements
|
to_replace_list = default_replacements
|
||||||
else:
|
else:
|
||||||
@@ -132,6 +171,17 @@ class Embedding:
|
|||||||
|
|
||||||
return output_prompt
|
return output_prompt
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
if self.sd.is_xl:
|
||||||
|
state_dict = OrderedDict()
|
||||||
|
state_dict['clip_l'] = self.vec
|
||||||
|
state_dict['clip_g'] = self.vec2
|
||||||
|
else:
|
||||||
|
state_dict = OrderedDict()
|
||||||
|
state_dict['emb_params'] = self.vec
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
# todo check to see how to get the vector out of the embedding
|
# todo check to see how to get the vector out of the embedding
|
||||||
|
|
||||||
@@ -145,13 +195,14 @@ class Embedding:
|
|||||||
"sd_checkpoint_name": None,
|
"sd_checkpoint_name": None,
|
||||||
"notes": None,
|
"notes": None,
|
||||||
}
|
}
|
||||||
|
# TODO we do not currently support this. Check how auto is doing it. Only safetensors supported sor sdxl
|
||||||
if filename.endswith('.pt'):
|
if filename.endswith('.pt'):
|
||||||
torch.save(embedding_data, filename)
|
torch.save(embedding_data, filename)
|
||||||
elif filename.endswith('.bin'):
|
elif filename.endswith('.bin'):
|
||||||
torch.save(embedding_data, filename)
|
torch.save(embedding_data, filename)
|
||||||
elif filename.endswith('.safetensors'):
|
elif filename.endswith('.safetensors'):
|
||||||
# save the embedding as a safetensors file
|
# save the embedding as a safetensors file
|
||||||
state_dict = {"emb_params": self.vec}
|
state_dict = self.state_dict()
|
||||||
# add all embedding data (except string_to_param), to metadata
|
# add all embedding data (except string_to_param), to metadata
|
||||||
metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
|
metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
|
||||||
metadata["string_to_param"] = {"*": "emb_params"}
|
metadata["string_to_param"] = {"*": "emb_params"}
|
||||||
@@ -163,6 +214,7 @@ class Embedding:
|
|||||||
path = os.path.realpath(file_path)
|
path = os.path.realpath(file_path)
|
||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
name, ext = os.path.splitext(filename)
|
name, ext = os.path.splitext(filename)
|
||||||
|
tensors = {}
|
||||||
ext = ext.upper()
|
ext = ext.upper()
|
||||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||||
_, second_ext = os.path.splitext(name)
|
_, second_ext = os.path.splitext(name)
|
||||||
@@ -170,10 +222,12 @@ class Embedding:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if ext in ['.BIN', '.PT']:
|
if ext in ['.BIN', '.PT']:
|
||||||
|
# todo check this
|
||||||
|
if self.sd.is_xl:
|
||||||
|
raise Exception("XL not supported yet for bin, pt")
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
elif ext in ['.SAFETENSORS']:
|
elif ext in ['.SAFETENSORS']:
|
||||||
# rebuild the embedding from the safetensors file if it has it
|
# rebuild the embedding from the safetensors file if it has it
|
||||||
tensors = {}
|
|
||||||
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
|
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
@@ -217,4 +271,8 @@ class Embedding:
|
|||||||
if 'step' in data:
|
if 'step' in data:
|
||||||
self.step = int(data['step'])
|
self.step = int(data['step'])
|
||||||
|
|
||||||
self.vec = emb.detach().to(device, dtype=torch.float32)
|
if self.sd.is_xl:
|
||||||
|
self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
|
||||||
|
self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
self.vec = emb.detach().to(device, dtype=torch.float32)
|
||||||
|
|||||||
@@ -228,9 +228,15 @@ class ToolkitNetworkMixin:
|
|||||||
|
|
||||||
return keymap
|
return keymap
|
||||||
|
|
||||||
def save_weights(self: Network, file, dtype=torch.float16, metadata=None):
|
def save_weights(
|
||||||
|
self: Network,
|
||||||
|
file, dtype=torch.float16,
|
||||||
|
metadata=None,
|
||||||
|
extra_state_dict: Optional[OrderedDict] = None
|
||||||
|
):
|
||||||
keymap = self.get_keymap()
|
keymap = self.get_keymap()
|
||||||
|
|
||||||
|
|
||||||
save_keymap = {}
|
save_keymap = {}
|
||||||
if keymap is not None:
|
if keymap is not None:
|
||||||
for ldm_key, diffusers_key in keymap.items():
|
for ldm_key, diffusers_key in keymap.items():
|
||||||
@@ -249,6 +255,13 @@ class ToolkitNetworkMixin:
|
|||||||
save_key = save_keymap[key] if key in save_keymap else key
|
save_key = save_keymap[key] if key in save_keymap else key
|
||||||
save_dict[save_key] = v
|
save_dict[save_key] = v
|
||||||
|
|
||||||
|
if extra_state_dict is not None:
|
||||||
|
# add extra items to state dict
|
||||||
|
for key in list(extra_state_dict.keys()):
|
||||||
|
v = extra_state_dict[key]
|
||||||
|
v = v.detach().clone().to("cpu").to(dtype)
|
||||||
|
save_dict[key] = v
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = OrderedDict()
|
metadata = OrderedDict()
|
||||||
metadata = add_model_hash_to_meta(state_dict, metadata)
|
metadata = add_model_hash_to_meta(state_dict, metadata)
|
||||||
@@ -275,8 +288,21 @@ class ToolkitNetworkMixin:
|
|||||||
load_key = keymap[key] if key in keymap else key
|
load_key = keymap[key] if key in keymap else key
|
||||||
load_sd[load_key] = value
|
load_sd[load_key] = value
|
||||||
|
|
||||||
|
# extract extra items from state dict
|
||||||
|
current_state_dict = self.state_dict()
|
||||||
|
extra_dict = OrderedDict()
|
||||||
|
to_delete = []
|
||||||
|
for key in list(load_sd.keys()):
|
||||||
|
if key not in current_state_dict:
|
||||||
|
extra_dict[key] = load_sd[key]
|
||||||
|
to_delete.append(key)
|
||||||
|
for key in to_delete:
|
||||||
|
del load_sd[key]
|
||||||
|
|
||||||
info = self.load_state_dict(load_sd, False)
|
info = self.load_state_dict(load_sd, False)
|
||||||
return info
|
if len(extra_dict.keys()) == 0:
|
||||||
|
extra_dict = None
|
||||||
|
return extra_dict
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def multiplier(self) -> Union[float, List[float]]:
|
def multiplier(self) -> Union[float, List[float]]:
|
||||||
|
|||||||
@@ -442,21 +442,25 @@ class StableDiffusion:
|
|||||||
return noise
|
return noise
|
||||||
|
|
||||||
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
||||||
bs, ch, h, w = list(latents.shape)
|
|
||||||
|
|
||||||
height = h * VAE_SCALE_FACTOR
|
|
||||||
width = w * VAE_SCALE_FACTOR
|
|
||||||
|
|
||||||
dtype = latents.dtype
|
|
||||||
|
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
prompt_ids = train_tools.get_add_time_ids(
|
bs, ch, h, w = list(latents.shape)
|
||||||
height,
|
|
||||||
width,
|
height = h * VAE_SCALE_FACTOR
|
||||||
dynamic_crops=False, # look into this
|
width = w * VAE_SCALE_FACTOR
|
||||||
dtype=dtype,
|
|
||||||
).to(self.device_torch, dtype=dtype)
|
dtype = latents.dtype
|
||||||
return prompt_ids
|
# just do it without any cropping nonsense
|
||||||
|
target_size = (height, width)
|
||||||
|
original_size = (height, width)
|
||||||
|
crops_coords_top_left = (0, 0)
|
||||||
|
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||||
|
add_time_ids = torch.tensor([add_time_ids])
|
||||||
|
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
|
||||||
|
|
||||||
|
batch_time_ids = torch.cat(
|
||||||
|
[add_time_ids for _ in range(bs)]
|
||||||
|
)
|
||||||
|
return batch_time_ids
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -682,7 +686,7 @@ class StableDiffusion:
|
|||||||
if self.vae.device == 'cpu':
|
if self.vae.device == 'cpu':
|
||||||
self.vae.to(self.device)
|
self.vae.to(self.device)
|
||||||
latents = latents.to(device, dtype=dtype)
|
latents = latents.to(device, dtype=dtype)
|
||||||
latents = latents / 0.18215
|
latents = latents / self.vae.config['scaling_factor']
|
||||||
images = self.vae.decode(latents).sample
|
images = self.vae.decode(latents).sample
|
||||||
images = images.to(device, dtype=dtype)
|
images = images.to(device, dtype=dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user