mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Tied in ant tested TI script
This commit is contained in:
@@ -52,41 +52,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
|
|||||||
# very loosely based on this. very loosely
|
# very loosely based on this. very loosely
|
||||||
# ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
|
# ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
|
||||||
|
|
||||||
conditioned_prompts = []
|
# make sure the embedding is in the prompts
|
||||||
|
conditioned_prompts = [self.embedding.inject_embedding_to_prompt(
|
||||||
|
x,
|
||||||
|
expand_token=True,
|
||||||
|
add_if_not_present=True,
|
||||||
|
) for x in prompts]
|
||||||
|
|
||||||
for prompt in prompts:
|
batch_size = imgs.shape[0]
|
||||||
# replace our name with the embedding
|
|
||||||
if self.embed_config.trigger in prompt:
|
|
||||||
# if the trigger is a part of the prompt, replace it with the token ids
|
|
||||||
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
|
|
||||||
if self.name in prompt:
|
|
||||||
# if the name is in the prompt, replace it with the trigger
|
|
||||||
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
|
|
||||||
if "[name]" in prompt:
|
|
||||||
# in [name] in prompt, replace it with the trigger
|
|
||||||
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
|
|
||||||
if self.embedding.get_embedding_string() not in prompt:
|
|
||||||
# add it to the beginning of the prompt
|
|
||||||
prompt = self.embedding.get_embedding_string() + " " + prompt
|
|
||||||
|
|
||||||
conditioned_prompts.append(prompt)
|
|
||||||
|
|
||||||
# # get embedding ids
|
|
||||||
# embedding_ids_list = [self.sd.tokenizer(
|
|
||||||
# text,
|
|
||||||
# padding="max_length",
|
|
||||||
# truncation=True,
|
|
||||||
# max_length=self.sd.tokenizer.model_max_length,
|
|
||||||
# return_tensors="pt",
|
|
||||||
# ).input_ids[0] for text in conditioned_prompts]
|
|
||||||
|
|
||||||
# hidden_states = []
|
|
||||||
# for embedding_ids, img in zip(embedding_ids_list, imgs):
|
|
||||||
# hidden_state = {
|
|
||||||
# "input_ids": embedding_ids,
|
|
||||||
# "pixel_values": img
|
|
||||||
# }
|
|
||||||
# hidden_states.append(hidden_state)
|
|
||||||
|
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||||
@@ -100,14 +73,14 @@ class TextualInversionTrainer(BaseSDTrainProcess):
|
|||||||
self.train_config.max_denoising_steps, device=self.device_torch
|
self.train_config.max_denoising_steps, device=self.device_torch
|
||||||
)
|
)
|
||||||
|
|
||||||
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
|
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
|
||||||
timesteps = timesteps.long()
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
# get noise
|
# get noise
|
||||||
noise = self.sd.get_latent_noise(
|
noise = self.sd.get_latent_noise(
|
||||||
pixel_height=imgs.shape[2],
|
pixel_height=imgs.shape[2],
|
||||||
pixel_width=imgs.shape[3],
|
pixel_width=imgs.shape[3],
|
||||||
batch_size=self.train_config.batch_size,
|
batch_size=batch_size,
|
||||||
noise_offset=self.train_config.noise_offset
|
noise_offset=self.train_config.noise_offset
|
||||||
).to(self.device_torch, dtype=dtype)
|
).to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -106,20 +106,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
prompt = sample_config.prompts[i]
|
prompt = sample_config.prompts[i]
|
||||||
|
|
||||||
# add embedding if there is one
|
# add embedding if there is one
|
||||||
|
# note: diffusers will automatically expand the trigger to the number of added tokens
|
||||||
|
# ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
|
||||||
if self.embedding is not None:
|
if self.embedding is not None:
|
||||||
# replace our name with the embedding
|
prompt = self.embedding.inject_embedding_to_prompt(
|
||||||
if self.embed_config.trigger in prompt:
|
prompt,
|
||||||
# if the trigger is a part of the prompt, replace it with the token ids
|
)
|
||||||
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
|
|
||||||
if self.name in prompt:
|
|
||||||
# if the name is in the prompt, replace it with the trigger
|
|
||||||
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
|
|
||||||
if "[name]" in prompt:
|
|
||||||
# in [name] in prompt, replace it with the trigger
|
|
||||||
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
|
|
||||||
if self.embedding.get_embedding_string() not in prompt:
|
|
||||||
# add it to the beginning of the prompt
|
|
||||||
prompt = self.embedding.get_embedding_string() + " " + prompt
|
|
||||||
|
|
||||||
gen_img_config_list.append(GenerateImageConfig(
|
gen_img_config_list.append(GenerateImageConfig(
|
||||||
prompt=prompt, # it will autoparse the prompt
|
prompt=prompt, # it will autoparse the prompt
|
||||||
@@ -208,6 +200,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
)
|
)
|
||||||
self.network.multiplier = prev_multiplier
|
self.network.multiplier = prev_multiplier
|
||||||
elif self.embedding is not None:
|
elif self.embedding is not None:
|
||||||
|
# set current step
|
||||||
|
self.embedding.step = self.step_num
|
||||||
|
# change filename to pt if that is set
|
||||||
|
if self.embed_config.save_format == "pt":
|
||||||
|
# replace extension
|
||||||
|
file_path = os.path.splitext(file_path)[0] + ".pt"
|
||||||
self.embedding.save(file_path)
|
self.embedding.save(file_path)
|
||||||
else:
|
else:
|
||||||
self.sd.save(
|
self.sd.save(
|
||||||
@@ -234,7 +232,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
def before_dataset_load(self):
|
def before_dataset_load(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def hook_train_loop(self, batch=None):
|
def hook_train_loop(self, batch):
|
||||||
# return loss
|
# return loss
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
@@ -365,6 +363,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if latest_save_path is not None:
|
if latest_save_path is not None:
|
||||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||||
|
|
||||||
|
# resume state from embedding
|
||||||
|
self.step_num = self.embedding.step
|
||||||
|
|
||||||
# set trainable params
|
# set trainable params
|
||||||
params = self.embedding.get_trainable_params()
|
params = self.embedding.get_trainable_params()
|
||||||
|
|
||||||
|
|||||||
@@ -54,9 +54,10 @@ class EmbeddingConfig:
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.trigger = kwargs.get('trigger', 'custom_embedding')
|
self.trigger = kwargs.get('trigger', 'custom_embedding')
|
||||||
self.tokens = kwargs.get('tokens', 4)
|
self.tokens = kwargs.get('tokens', 4)
|
||||||
self.init_words = kwargs.get('init_phrase', '*')
|
self.init_words = kwargs.get('init_words', '*')
|
||||||
self.save_format = kwargs.get('save_format', 'safetensors')
|
self.save_format = kwargs.get('save_format', 'safetensors')
|
||||||
|
|
||||||
|
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
|
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
|
||||||
@@ -75,6 +76,7 @@ class TrainConfig:
|
|||||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||||
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
||||||
|
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
@@ -165,6 +167,7 @@ class DatasetConfig:
|
|||||||
self.random_crop: bool = kwargs.get('random_crop', False)
|
self.random_crop: bool = kwargs.get('random_crop', False)
|
||||||
self.resolution: int = kwargs.get('resolution', 512)
|
self.resolution: int = kwargs.get('resolution', 512)
|
||||||
self.scale: float = kwargs.get('scale', 1.0)
|
self.scale: float = kwargs.get('scale', 1.0)
|
||||||
|
self.buckets: bool = kwargs.get('buckets', False)
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageConfig:
|
class GenerateImageConfig:
|
||||||
|
|||||||
@@ -14,6 +14,13 @@ import albumentations as A
|
|||||||
from toolkit.config_modules import DatasetConfig
|
from toolkit.config_modules import DatasetConfig
|
||||||
from toolkit.dataloader_mixins import CaptionMixin
|
from toolkit.dataloader_mixins import CaptionMixin
|
||||||
|
|
||||||
|
BUCKET_STEPS = 64
|
||||||
|
|
||||||
|
def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
|
||||||
|
# make sure resolution is divisible by 8
|
||||||
|
if resolution % 8 != 0:
|
||||||
|
resolution = resolution - (resolution % 8)
|
||||||
|
|
||||||
|
|
||||||
class ImageDataset(Dataset, CaptionMixin):
|
class ImageDataset(Dataset, CaptionMixin):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
@@ -357,6 +364,7 @@ class AiToolkitDataset(Dataset, CaptionMixin):
|
|||||||
|
|
||||||
|
|
||||||
def get_dataloader_from_datasets(dataset_options, batch_size=1):
|
def get_dataloader_from_datasets(dataset_options, batch_size=1):
|
||||||
|
# TODO do bucketing
|
||||||
if dataset_options is None or len(dataset_options) == 0:
|
if dataset_options is None or len(dataset_options) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ class Embedding:
|
|||||||
):
|
):
|
||||||
self.name = embed_config.trigger
|
self.name = embed_config.trigger
|
||||||
self.sd = sd
|
self.sd = sd
|
||||||
|
self.trigger = embed_config.trigger
|
||||||
self.embed_config = embed_config
|
self.embed_config = embed_config
|
||||||
|
self.step = 0
|
||||||
# setup our embedding
|
# setup our embedding
|
||||||
# Add the placeholder token in tokenizer
|
# Add the placeholder token in tokenizer
|
||||||
placeholder_tokens = [self.embed_config.trigger]
|
placeholder_tokens = [self.embed_config.trigger]
|
||||||
@@ -64,10 +66,7 @@ class Embedding:
|
|||||||
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
|
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
|
||||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||||
|
|
||||||
# this doesnt seem to be used again
|
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
|
||||||
self.token_embeds = token_embeds
|
|
||||||
|
|
||||||
# replace "[name] with this. This triggers it in the text encoder
|
|
||||||
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
|
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
|
||||||
|
|
||||||
# returns the string to have in the prompt to trigger the embedding
|
# returns the string to have in the prompt to trigger the embedding
|
||||||
@@ -86,7 +85,7 @@ class Embedding:
|
|||||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
||||||
# stack the tokens along batch axis adding that axis
|
# stack the tokens along batch axis adding that axis
|
||||||
new_vector = torch.stack(
|
new_vector = torch.stack(
|
||||||
[token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids],
|
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
|
||||||
dim=0
|
dim=0
|
||||||
)
|
)
|
||||||
return new_vector
|
return new_vector
|
||||||
@@ -100,6 +99,39 @@ class Embedding:
|
|||||||
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
|
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
|
||||||
x = 1
|
x = 1
|
||||||
|
|
||||||
|
# diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
|
||||||
|
# however, on training we don't use that pipeline, so we have to do it ourselves
|
||||||
|
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
|
||||||
|
output_prompt = prompt
|
||||||
|
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
|
||||||
|
|
||||||
|
replace_with = self.embedding_tokens if expand_token else self.trigger
|
||||||
|
if to_replace_list is None:
|
||||||
|
to_replace_list = default_replacements
|
||||||
|
else:
|
||||||
|
to_replace_list += default_replacements
|
||||||
|
|
||||||
|
# remove duplicates
|
||||||
|
to_replace_list = list(set(to_replace_list))
|
||||||
|
|
||||||
|
# replace them all
|
||||||
|
for to_replace in to_replace_list:
|
||||||
|
# replace it
|
||||||
|
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||||
|
|
||||||
|
# see how many times replace_with is in the prompt
|
||||||
|
num_instances = prompt.count(replace_with)
|
||||||
|
|
||||||
|
if num_instances == 0 and add_if_not_present:
|
||||||
|
# add it to the beginning of the prompt
|
||||||
|
output_prompt = replace_with + " " + output_prompt
|
||||||
|
|
||||||
|
if num_instances > 1:
|
||||||
|
print(
|
||||||
|
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||||
|
|
||||||
|
return output_prompt
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
# todo check to see how to get the vector out of the embedding
|
# todo check to see how to get the vector out of the embedding
|
||||||
|
|
||||||
@@ -107,7 +139,7 @@ class Embedding:
|
|||||||
"string_to_token": {"*": 265},
|
"string_to_token": {"*": 265},
|
||||||
"string_to_param": {"*": self.vec},
|
"string_to_param": {"*": self.vec},
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"step": 0,
|
"step": self.step,
|
||||||
# todo get these
|
# todo get these
|
||||||
"sd_checkpoint": None,
|
"sd_checkpoint": None,
|
||||||
"sd_checkpoint_name": None,
|
"sd_checkpoint_name": None,
|
||||||
@@ -182,4 +214,7 @@ class Embedding:
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||||
|
|
||||||
|
if 'step' in data:
|
||||||
|
self.step = int(data['step'])
|
||||||
|
|
||||||
self.vec = emb.detach().to(device, dtype=torch.float32)
|
self.vec = emb.detach().to(device, dtype=torch.float32)
|
||||||
|
|||||||
@@ -435,7 +435,7 @@ class StableDiffusion:
|
|||||||
text_embeddings = train_tools.concat_prompt_embeddings(
|
text_embeddings = train_tools.concat_prompt_embeddings(
|
||||||
unconditional_embeddings, # negative embedding
|
unconditional_embeddings, # negative embedding
|
||||||
conditional_embeddings, # positive embedding
|
conditional_embeddings, # positive embedding
|
||||||
latents.shape[0], # batch size
|
1, # batch size
|
||||||
)
|
)
|
||||||
elif text_embeddings is None and conditional_embeddings is not None:
|
elif text_embeddings is None and conditional_embeddings is not None:
|
||||||
# not doing cfg
|
# not doing cfg
|
||||||
@@ -506,6 +506,17 @@ class StableDiffusion:
|
|||||||
|
|
||||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||||
|
|
||||||
|
# check if we need to concat timesteps
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
ts_bs = timestep.shape[0]
|
||||||
|
if ts_bs != latent_model_input.shape[0]:
|
||||||
|
if ts_bs == 1:
|
||||||
|
timestep = torch.cat([timestep] * latent_model_input.shape[0])
|
||||||
|
elif ts_bs * 2 == latent_model_input.shape[0]:
|
||||||
|
timestep = torch.cat([timestep] * 2)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
|
|||||||
Reference in New Issue
Block a user