mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added support for training lora, dreambooth, and fine tuning. Still need testing and docs
This commit is contained in:
@@ -61,11 +61,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.data_loader: Union[DataLoader, None] = None
|
||||
self.data_loader_reg: Union[DataLoader, None] = None
|
||||
self.trigger_word = self.get_conf('trigger_word', None)
|
||||
|
||||
raw_datasets = self.get_conf('datasets', None)
|
||||
self.datasets = None
|
||||
self.datasets_reg = None
|
||||
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||
self.datasets = [DatasetConfig(**d) for d in raw_datasets]
|
||||
for raw_dataset in raw_datasets:
|
||||
dataset = DatasetConfig(**raw_dataset)
|
||||
if dataset.is_reg:
|
||||
if self.datasets_reg is None:
|
||||
self.datasets_reg = []
|
||||
self.datasets_reg.append(dataset)
|
||||
else:
|
||||
if self.datasets is None:
|
||||
self.datasets = []
|
||||
self.datasets.append(dataset)
|
||||
|
||||
self.embed_config = None
|
||||
embedding_raw = self.get_conf('embedding', None)
|
||||
@@ -112,6 +124,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
)
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt, self.trigger_word
|
||||
)
|
||||
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
@@ -275,6 +291,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# load datasets if passed in the root process
|
||||
if self.datasets is not None:
|
||||
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
|
||||
if self.datasets_reg is not None:
|
||||
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size)
|
||||
|
||||
### HOOK ###
|
||||
self.hook_before_model_load()
|
||||
@@ -433,14 +451,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dataloader = None
|
||||
dataloader_iterator = None
|
||||
|
||||
if self.data_loader_reg is not None:
|
||||
dataloader_reg = self.data_loader_reg
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
else:
|
||||
dataloader_reg = None
|
||||
dataloader_iterator_reg = None
|
||||
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
if dataloader is not None:
|
||||
# if is even step and we have a reg dataset, use that
|
||||
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
|
||||
if step % 2 == 0 and dataloader_reg is not None:
|
||||
try:
|
||||
batch = next(dataloader_iterator_reg)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
batch = next(dataloader_iterator_reg)
|
||||
elif dataloader is not None:
|
||||
try:
|
||||
batch = next(dataloader_iterator)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
# todo, should we do something else here? like blow up balloons?
|
||||
dataloader_iterator = iter(dataloader)
|
||||
batch = next(dataloader_iterator)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user