Hude rework to move the batch to a DTO to make it far more modular to the future ui

This commit is contained in:
Jaret Burkett
2023-08-29 10:22:19 -06:00
parent bd758ff203
commit 714854ee86
10 changed files with 286 additions and 232 deletions

View File

@@ -6,6 +6,7 @@ from typing import Union
from torch.utils.data import DataLoader
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -23,7 +24,7 @@ import torch
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config
def flush():
@@ -67,6 +68,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.trigger_word = self.get_conf('trigger_word', None)
raw_datasets = self.get_conf('datasets', None)
if raw_datasets is not None and len(raw_datasets) > 0:
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
self.datasets = None
self.datasets_reg = None
if raw_datasets is not None and len(raw_datasets) > 0:
@@ -94,6 +97,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.model_config.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
self.sd = StableDiffusion(
device=self.device,
@@ -307,16 +316,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
def process_general_training_batch(self, batch):
with torch.no_grad():
imgs, prompts, dataset_config = batch
# convert the 0 or 1 for is reg to a bool list
if isinstance(dataset_config, list):
is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
else:
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(is_reg_list, torch.Tensor):
is_reg_list = is_reg_list.numpy().tolist()
is_reg_list = [bool(x) for x in is_reg_list]
imgs = batch.tensor
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
conditioned_prompts = []
@@ -473,6 +475,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num
# set trainable params
params = self.embedding.get_trainable_params()
@@ -556,13 +559,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
with torch.no_grad():
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
if step % 2 == 0 and dataloader_reg is not None:
is_reg_step = False
is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
# don't do a reg step on sample or save steps as we dont want to normalize on those
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
try:
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator_reg = iter(dataloader_reg)
batch = next(dataloader_iterator_reg)
is_reg_step = True
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
@@ -601,11 +609,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.step_num != self.start_step:
# pause progress bar
self.progress_bar.unpause() # makes it so doesn't track time
if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
if is_sample_step:
# print above the progress bar
self.sample(self.step_num)
if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
if is_save_step:
# print above the progress bar
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
@@ -623,10 +631,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
# end of step
self.step_num = step
# apply network normalizer if we are using it
if self.network is not None and self.network.is_normalizing:
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
batch.cleanup()
self.sample(self.step_num + 1)
print("")
self.save()

View File

@@ -1,76 +0,0 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os
from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
def flush():
torch.cuda.empty_cache()
gc.collect()
class LoRAHack:
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'suppression')
class TrainLoRAHack(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.hack_config = LoRAHack(**self.get_conf('hack', {}))
def hook_before_train_loop(self):
# we don't need text encoder so move it to cpu
self.sd.text_encoder.to("cpu")
flush()
# end hook_before_train_loop
if self.hack_config.type == 'suppression':
# set all params to self.current_suppression
params = self.network.parameters()
for param in params:
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = noise * 0.001
def supress_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
loss_dict = OrderedDict(
{'sup': 0.0}
)
# increase noise
for param in self.network.parameters():
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = param.data + noise * 0.001
return loss_dict
def hook_train_loop(self, batch):
if self.hack_config.type == 'suppression':
return self.supress_loop()
else:
raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
# end hook_train_loop

View File

@@ -7,7 +7,6 @@ from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess