WIP creating textual inversion training script

This commit is contained in:
Jaret Burkett
2023-08-22 21:02:38 -06:00
parent 36ba08d3fa
commit 2e6c55c720
9 changed files with 746 additions and 6 deletions

View File

@@ -5,6 +5,8 @@ from typing import Union
from torch.utils.data import DataLoader
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -20,7 +22,7 @@ import torch
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig
GenerateImageConfig, EmbeddingConfig, DatasetConfig
def flush():
@@ -30,6 +32,7 @@ def flush():
class BaseSDTrainProcess(BaseTrainProcess):
sd: StableDiffusion
embedding: Union[Embedding, None] = None
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
super().__init__(process_id, job, config)
@@ -59,6 +62,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
raw_datasets = self.get_conf('datasets', None)
self.datasets = None
if raw_datasets is not None and len(raw_datasets) > 0:
self.datasets = [DatasetConfig(**d) for d in raw_datasets]
self.embed_config = None
embedding_raw = self.get_conf('embedding', None)
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
@@ -68,6 +81,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network = None
self.embedding = None
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
@@ -89,8 +103,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
output_path = os.path.join(sample_folder, filename)
prompt = sample_config.prompts[i]
# add embedding if there is one
if self.embedding is not None:
# replace our name with the embedding
if self.embed_config.trigger in prompt:
# if the trigger is a part of the prompt, replace it with the token ids
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
if self.name in prompt:
# if the name is in the prompt, replace it with the embedding string
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
if "[name]" in prompt:
# if [name] is in the prompt, replace it with the embedding string
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
if self.embedding.get_embedding_string() not in prompt:
# add it to the beginning of the prompt
prompt = self.embedding.get_embedding_string() + " " + prompt
gen_img_config_list.append(GenerateImageConfig(
prompt=sample_config.prompts[i], # it will autoparse the prompt
prompt=prompt, # it will autoparse the prompt
width=sample_config.width,
height=sample_config.height,
negative_prompt=sample_config.neg,
@@ -175,6 +207,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
metadata=save_meta
)
self.network.multiplier = prev_multiplier
elif self.embedding is not None:
self.embedding.save(file_path)
else:
self.sd.save(
file_path,
@@ -197,6 +231,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# Extension hook with an empty body in this class.
# NOTE(review): the call site is not visible in this excerpt; the name suggests
# it runs just before the training loop starts and is meant to be overridden
# by subclasses -- confirm.
def hook_before_train_loop(self):
pass
# Extension hook with an empty body in this class.
# run() calls it right after BaseTrainProcess.run(self) and before the data
# loader is built from self.datasets.
# NOTE(review): presumably meant to be overridden by subclasses -- confirm.
def before_dataset_load(self):
pass
# Placeholder training-step hook: this implementation ignores `batch` and
# returns a constant 0.0.
# NOTE(review): presumably subclasses override this to run one step on `batch`
# and return the resulting loss -- confirm.
def hook_train_loop(self, batch=None):
# return loss
return 0.0
@@ -208,6 +245,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
pattern = f"{self.job.name}*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
# try pt
pattern = f"{self.job.name}*.pt"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
return latest_file
@@ -230,11 +272,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
def run(self):
# run base process run
BaseTrainProcess.run(self)
### HOOK ###
self.before_dataset_load()
# load datasets if passed in the root process
if self.datasets is not None:
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
### HOOK ###
self.hook_before_model_load()
# run base sd process run
self.sd.load_model()
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
dtype = get_torch_dtype(self.train_config.dtype)
# model is loaded from BaseSDProcess
@@ -303,7 +355,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print(f"Loading from {latest_save_path}")
self.load_weights(latest_save_path)
self.network.multiplier = 1.0
elif self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path()
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# set trainable params
params = self.embedding.get_trainable_params()
else:
params = []