mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
WIP creating textual inversion training script
This commit is contained in:
@@ -5,6 +5,8 @@ from typing import Union
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from toolkit.data_loader import get_dataloader_from_datasets
|
||||
from toolkit.embedding import Embedding
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.optimizer import get_optimizer
|
||||
|
||||
@@ -20,7 +22,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
||||
GenerateImageConfig
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -30,6 +32,7 @@ def flush():
|
||||
|
||||
class BaseSDTrainProcess(BaseTrainProcess):
|
||||
sd: StableDiffusion
|
||||
embedding: Union[Embedding, None] = None
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
|
||||
super().__init__(process_id, job, config)
|
||||
@@ -59,6 +62,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.lr_scheduler = None
|
||||
self.data_loader: Union[DataLoader, None] = None
|
||||
|
||||
raw_datasets = self.get_conf('datasets', None)
|
||||
self.datasets = None
|
||||
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||
self.datasets = [DatasetConfig(**d) for d in raw_datasets]
|
||||
|
||||
self.embed_config = None
|
||||
embedding_raw = self.get_conf('embedding', None)
|
||||
if embedding_raw is not None:
|
||||
self.embed_config = EmbeddingConfig(**embedding_raw)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=self.model_config,
|
||||
@@ -68,6 +81,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# to hold network if there is one
|
||||
self.network = None
|
||||
self.embedding = None
|
||||
|
||||
def sample(self, step=None, is_first=False):
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
@@ -89,8 +103,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
output_path = os.path.join(sample_folder, filename)
|
||||
|
||||
prompt = sample_config.prompts[i]
|
||||
|
||||
# add embedding if there is one
|
||||
if self.embedding is not None:
|
||||
# replace our name with the embedding
|
||||
if self.embed_config.trigger in prompt:
|
||||
# if the trigger is a part of the prompt, replace it with the token ids
|
||||
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
|
||||
if self.name in prompt:
|
||||
# if the name is in the prompt, replace it with the embedding string
|
||||
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
|
||||
if "[name]" in prompt:
|
||||
# if [name] is in the prompt, replace it with the embedding string
|
||||
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
|
||||
if self.embedding.get_embedding_string() not in prompt:
|
||||
# add it to the beginning of the prompt
|
||||
prompt = self.embedding.get_embedding_string() + " " + prompt
|
||||
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
prompt=sample_config.prompts[i], # it will autoparse the prompt
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
width=sample_config.width,
|
||||
height=sample_config.height,
|
||||
negative_prompt=sample_config.neg,
|
||||
@@ -175,6 +207,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
metadata=save_meta
|
||||
)
|
||||
self.network.multiplier = prev_multiplier
|
||||
elif self.embedding is not None:
|
||||
self.embedding.save(file_path)
|
||||
else:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
@@ -197,6 +231,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def hook_before_train_loop(self):
|
||||
pass
|
||||
|
||||
def before_dataset_load(self):
|
||||
pass
|
||||
|
||||
def hook_train_loop(self, batch=None):
|
||||
# return loss
|
||||
return 0.0
|
||||
@@ -208,6 +245,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
|
||||
pattern = f"{self.job.name}*.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
# try pt
|
||||
pattern = f"{self.job.name}*.pt"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
return latest_file
|
||||
@@ -230,11 +272,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def run(self):
|
||||
# run base process run
|
||||
BaseTrainProcess.run(self)
|
||||
### HOOK ###
|
||||
self.before_dataset_load()
|
||||
# load datasets if passed in the root process
|
||||
if self.datasets is not None:
|
||||
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
|
||||
|
||||
### HOOK ###
|
||||
self.hook_before_model_load()
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
# may get disabled elsewhere
|
||||
self.sd.unet.enable_gradient_checkpointing()
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# model is loaded from BaseSDProcess
|
||||
@@ -303,7 +355,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.print(f"Loading from {latest_save_path}")
|
||||
self.load_weights(latest_save_path)
|
||||
self.network.multiplier = 1.0
|
||||
elif self.embed_config is not None:
|
||||
self.embedding = Embedding(
|
||||
sd=self.sd,
|
||||
embed_config=self.embed_config
|
||||
)
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
|
||||
else:
|
||||
params = []
|
||||
|
||||
Reference in New Issue
Block a user