Added support for training lora, dreambooth, and fine tuning. Still need testing and docs

This commit is contained in:
Jaret Burkett
2023-08-23 15:37:00 -06:00
parent e2c547f6c2
commit 7157c316af
8 changed files with 265 additions and 165 deletions

View File

@@ -61,11 +61,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.optimizer = None
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
self.data_loader_reg: Union[DataLoader, None] = None
self.trigger_word = self.get_conf('trigger_word', None)
raw_datasets = self.get_conf('datasets', None)
self.datasets = None
self.datasets_reg = None
if raw_datasets is not None and len(raw_datasets) > 0:
self.datasets = [DatasetConfig(**d) for d in raw_datasets]
for raw_dataset in raw_datasets:
dataset = DatasetConfig(**raw_dataset)
if dataset.is_reg:
if self.datasets_reg is None:
self.datasets_reg = []
self.datasets_reg.append(dataset)
else:
if self.datasets is None:
self.datasets = []
self.datasets.append(dataset)
self.embed_config = None
embedding_raw = self.get_conf('embedding', None)
@@ -112,6 +124,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
)
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
prompt, self.trigger_word
)
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
@@ -275,6 +291,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# load datasets if passed in the root process
if self.datasets is not None:
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
if self.datasets_reg is not None:
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size)
### HOOK ###
self.hook_before_model_load()
@@ -433,14 +451,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
dataloader = None
dataloader_iterator = None
if self.data_loader_reg is not None:
dataloader_reg = self.data_loader_reg
dataloader_iterator_reg = iter(dataloader_reg)
else:
dataloader_reg = None
dataloader_iterator_reg = None
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
if dataloader is not None:
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
if step % 2 == 0 and dataloader_reg is not None:
try:
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator_reg = iter(dataloader_reg)
batch = next(dataloader_iterator_reg)
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
# todo, should we do something else here? like blow up balloons?
dataloader_iterator = iter(dataloader)
batch = next(dataloader_iterator)
else: