Huge rework to move the batch to a DTO to make it far more modular for the future UI

Jaret Burkett
2023-08-29 10:22:19 -06:00
parent bd758ff203
commit 714854ee86
10 changed files with 286 additions and 232 deletions
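For orientation before the diff: the rework consumes batches through a DataLoaderBatchDTO instead of a raw (imgs, prompts, dataset_config) tuple. Below is a minimal sketch of the interface the hunks rely on, assuming a simple list-of-file-items layout; the class and method names (FileItemDTO, tensor, get_caption_list, get_is_reg_list, cleanup) come from the diff itself, while the constructor fields are assumptions, and the real classes in toolkit/data_transfer_object/data_loader.py are certainly richer.

from typing import List, Optional
import torch

class FileItemDTO:
    # hypothetical fields; only caption and is_reg are implied by the diff
    def __init__(self, path: str, caption: str, is_reg: bool = False):
        self.path = path
        self.caption = caption
        self.is_reg = is_reg

class DataLoaderBatchDTO:
    def __init__(self, file_items: List[FileItemDTO], tensor: torch.Tensor):
        self.file_items: Optional[List[FileItemDTO]] = file_items
        self.tensor: Optional[torch.Tensor] = tensor  # stacked images, e.g. (batch, C, H, W)

    def get_caption_list(self) -> List[str]:
        return [item.caption for item in self.file_items]

    def get_is_reg_list(self) -> List[bool]:
        return [item.is_reg for item in self.file_items]

    def cleanup(self):
        # drop references so tensors can be garbage collected at end of step
        self.tensor = None
        self.file_items = None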


@@ -6,6 +6,7 @@ from typing import Union
 from torch.utils.data import DataLoader
 from toolkit.data_loader import get_dataloader_from_datasets
+from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -23,7 +24,7 @@ import torch
 from tqdm import tqdm
 from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig, EmbeddingConfig, DatasetConfig
+    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config

 def flush():
@@ -67,6 +68,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.trigger_word = self.get_conf('trigger_word', None)
         raw_datasets = self.get_conf('datasets', None)
+        if raw_datasets is not None and len(raw_datasets) > 0:
+            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
         self.datasets = None
         self.datasets_reg = None
         if raw_datasets is not None and len(raw_datasets) > 0:
@@ -94,6 +97,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if latest_save_path is not None:
             print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
             self.model_config.name_or_path = latest_save_path
+            meta = load_metadata_from_safetensors(latest_save_path)
+            # if 'training_info' is in the metadata OrderedDict keys
+            if 'training_info' in meta and 'step' in meta['training_info']:
+                self.step_num = meta['training_info']['step']
+                self.start_step = self.step_num
+                print(f"Found step {self.step_num} in metadata, starting from there")

         self.sd = StableDiffusion(
             device=self.device,
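The resume logic in this hunk reads the step count back out of the checkpoint's safetensors header. A sketch of what load_metadata_from_safetensors plausibly does, assuming the trainer stores training_info as JSON in the header (safetensors metadata is str -> str); the helper's actual implementation in the toolkit may differ:

import json
from safetensors import safe_open

def load_metadata_from_safetensors(path: str) -> dict:
    # read the str -> str metadata dict from the safetensors header
    with safe_open(path, framework="pt") as f:
        raw = f.metadata() or {}
    meta = {}
    for key, value in raw.items():
        try:
            # nested values such as training_info are assumed to be JSON-encoded
            meta[key] = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            meta[key] = value
    return meta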
@@ -307,16 +316,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def process_general_training_batch(self, batch):
         with torch.no_grad():
-            imgs, prompts, dataset_config = batch
-            # convert the 0 or 1 for is reg to a bool list
-            if isinstance(dataset_config, list):
-                is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
-            else:
-                is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
-            if isinstance(is_reg_list, torch.Tensor):
-                is_reg_list = is_reg_list.numpy().tolist()
-            is_reg_list = [bool(x) for x in is_reg_list]
+            imgs = batch.tensor
+            prompts = batch.get_caption_list()
+            is_reg_list = batch.get_is_reg_list()

             conditioned_prompts = []
@@ -473,6 +475,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # resume state from embedding
             self.step_num = self.embedding.step
+            self.start_step = self.step_num

         # set trainable params
         params = self.embedding.get_trainable_params()
@@ -556,13 +559,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with torch.no_grad():
                 # if it is an even step and we have a reg dataset, use that
                 # todo: improve this logic to send one of each through if we can; buckets and batch size might be an issue
-                if step % 2 == 0 and dataloader_reg is not None:
+                is_reg_step = False
+                is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
+                is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
+                # don't do a reg step on sample or save steps as we don't want to normalize on those
+                if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
                     try:
                         batch = next(dataloader_iterator_reg)
                     except StopIteration:
                         # hit the end of an epoch, reset
                         dataloader_iterator_reg = iter(dataloader_reg)
                         batch = next(dataloader_iterator_reg)
+                    is_reg_step = True
                 elif dataloader is not None:
                     try:
                         batch = next(dataloader_iterator)
@@ -601,11 +609,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if self.step_num != self.start_step:
                 # pause progress bar
                 self.progress_bar.unpause()  # makes it so it doesn't track time
-                if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
+                if is_sample_step:
                     # print above the progress bar
                     self.sample(self.step_num)
-                if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
+                if is_save_step:
                     # print above the progress bar
                     self.print(f"Saving at step {self.step_num}")
                     self.save(self.step_num)
@@ -623,10 +631,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # end of step
                 self.step_num = step

-                # apply network normalizer if we are using it
-                if self.network is not None and self.network.is_normalizing:
+                # apply network normalizer if we are using it, but not on regularization steps
+                if self.network is not None and self.network.is_normalizing and not is_reg_step:
                     self.network.apply_stored_normalizer()

+                # if the batch is a DataLoaderBatchDTO, then we need to clean it up
+                if isinstance(batch, DataLoaderBatchDTO):
+                    batch.cleanup()

         self.sample(self.step_num + 1)
         print("")
         self.save()
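One mechanical consequence of hoisting is_save_step and is_sample_step above the batch fetch: a regularization step can never coincide with a save or sample step. A toy loop checking that property; the save_every/sample_every values are hypothetical and the flag expressions mirror the diff above:

save_every, sample_every = 100, 50
has_reg_dataloader = True

for step in range(1, 201):
    is_save_step = bool(save_every and step % save_every == 0)
    is_sample_step = bool(sample_every and step % sample_every == 0)
    is_reg_step = step % 2 == 0 and has_reg_dataloader and not is_save_step and not is_sample_step
    if is_save_step or is_sample_step:
        # save/sample steps are always drawn from the main dataset
        assert not is_reg_step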