mirror of https://github.com/ostris/ai-toolkit.git
Huge rework to move the batch to a DTO to make it far more modular for the future UI
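For orientation before the diff: below is a minimal sketch of the batch DTO this commit introduces. The class names and accessors (tensor, get_caption_list(), get_is_reg_list()) are taken directly from the diff; the FileItemDTO fields and the constructor shape are assumptions, not the toolkit's actual implementation.

```python
# Hedged sketch only: accessor names come from the diff below; the
# FileItemDTO fields and constructor are assumptions.
from typing import List

import torch


class FileItemDTO:
    def __init__(self, caption: str, is_reg: bool = False):
        self.caption = caption  # caption text for one image
        self.is_reg = is_reg    # True if this is a regularization image


class DataLoaderBatchDTO:
    def __init__(self, tensor: torch.Tensor, file_items: List[FileItemDTO]):
        self.tensor = tensor          # stacked (B, C, H, W) image tensor
        self.file_items = file_items  # per-image metadata, one per batch row

    def get_caption_list(self) -> List[str]:
        return [item.caption for item in self.file_items]

    def get_is_reg_list(self) -> List[bool]:
        return [bool(item.is_reg) for item in self.file_items]
```

Packaging the batch this way means every consumer (the trainer today, a UI later) talks to one typed object instead of unpacking a tuple whose layout can drift.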
@@ -6,6 +6,7 @@ from typing import Union
 from torch.utils.data import DataLoader
 from toolkit.data_loader import get_dataloader_from_datasets
+from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -23,7 +24,7 @@ import torch
 from tqdm import tqdm

 from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig, EmbeddingConfig, DatasetConfig
+    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config


 def flush():
@@ -67,6 +68,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.trigger_word = self.get_conf('trigger_word', None)

         raw_datasets = self.get_conf('datasets', None)
+        if raw_datasets is not None and len(raw_datasets) > 0:
+            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
         self.datasets = None
         self.datasets_reg = None
         if raw_datasets is not None and len(raw_datasets) > 0:
@@ -94,6 +97,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if latest_save_path is not None:
                 print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
                 self.model_config.name_or_path = latest_save_path
+                meta = load_metadata_from_safetensors(latest_save_path)
+                # if 'training_info' is in the OrderedDict keys
+                if 'training_info' in meta and 'step' in meta['training_info']:
+                    self.step_num = meta['training_info']['step']
+                    self.start_step = self.step_num
+                    print(f"Found step {self.step_num} in metadata, starting from there")

         self.sd = StableDiffusion(
             device=self.device,
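The resume logic added above pulls the step counter out of safetensors metadata. load_metadata_from_safetensors is the toolkit's own helper; as a hedged sketch of roughly what such a helper can do with the real safetensors API:

```python
# Hedged sketch: an approximation of a load_metadata_from_safetensors
# helper; the toolkit's actual implementation may differ.
import json

from safetensors import safe_open


def load_metadata_from_safetensors(path: str) -> dict:
    with safe_open(path, framework="pt") as f:
        meta = f.metadata() or {}
    # safetensors metadata values are plain strings, so structured entries
    # like 'training_info' are usually stored as JSON strings
    return {key: _maybe_json(value) for key, value in meta.items()}


def _maybe_json(value: str):
    try:
        return json.loads(value)
    except (TypeError, ValueError):
        return value
```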
@@ -307,16 +316,9 @@ class BaseSDTrainProcess(BaseTrainProcess):

     def process_general_training_batch(self, batch):
         with torch.no_grad():
-            imgs, prompts, dataset_config = batch
-
-            # convert the 0 or 1 for is reg to a bool list
-            if isinstance(dataset_config, list):
-                is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
-            else:
-                is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
-            if isinstance(is_reg_list, torch.Tensor):
-                is_reg_list = is_reg_list.numpy().tolist()
-            is_reg_list = [bool(x) for x in is_reg_list]
+            imgs = batch.tensor
+            prompts = batch.get_caption_list()
+            is_reg_list = batch.get_is_reg_list()

             conditioned_prompts = []
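This hunk is the heart of the rework: tuple unpacking plus ad-hoc 0/1-to-bool coercion is replaced by three DTO accessors. A hedged usage sketch of a consuming loop (the dataloader wiring is assumed):

```python
# Hedged usage sketch: each batch from the new dataloader is a
# DataLoaderBatchDTO rather than an (imgs, prompts, config) tuple.
for batch in dataloader:
    imgs = batch.tensor                    # (B, C, H, W) image tensor
    prompts = batch.get_caption_list()     # list[str], one caption per image
    is_reg_list = batch.get_is_reg_list()  # list[bool], coercion lives in the DTO
    # ... run the training step on imgs / prompts / is_reg_list ...
    batch.cleanup()                        # free batch memory (see the cleanup hunk below)
```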
@@ -473,6 +475,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

             # resume state from embedding
             self.step_num = self.embedding.step
+            self.start_step = self.step_num

             # set trainable params
             params = self.embedding.get_trainable_params()
@@ -556,13 +559,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with torch.no_grad():
                 # if it is an even step and we have a reg dataset, use that
                 # todo improve this logic to send one of each through if we can; buckets and batch size might be an issue
-                if step % 2 == 0 and dataloader_reg is not None:
+                is_reg_step = False
+                is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
+                is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
+                # don't do a reg step on sample or save steps as we don't want to normalize on those
+                if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
                     try:
                         batch = next(dataloader_iterator_reg)
                     except StopIteration:
                         # hit the end of an epoch, reset
                         dataloader_iterator_reg = iter(dataloader_reg)
                         batch = next(dataloader_iterator_reg)
+                    is_reg_step = True
                 elif dataloader is not None:
                     try:
                         batch = next(dataloader_iterator)
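The gating above alternates regularization batches onto even steps, but now also skips them on save and sample steps: the normalizer is deliberately not applied on reg steps (see the hunk further down), and saves/samples should land on normalized weights. The same decision as a hedged, pure-function sketch (names are illustrative, not from the toolkit):

```python
# Hedged sketch: the reg-step decision from this hunk, factored into a
# pure function. Parameter names are illustrative.
def should_do_reg_step(step: int, has_reg_dataloader: bool,
                       is_save_step: bool, is_sample_step: bool) -> bool:
    # reg batches run on even steps only, and never on a step that also
    # saves or samples, so those outputs reflect normalized weights
    return (step % 2 == 0
            and has_reg_dataloader
            and not is_save_step
            and not is_sample_step)
```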
@@ -601,11 +609,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.step_num != self.start_step:
                     # pause progress bar
                     self.progress_bar.unpause()  # makes it so it doesn't track time
-                    if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
+                    if is_sample_step:
                         # print above the progress bar
                         self.sample(self.step_num)

-                    if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
+                    if is_save_step:
                         # print above the progress bar
                         self.print(f"Saving at step {self.step_num}")
                         self.save(self.step_num)
@@ -623,10 +631,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # end of step
                 self.step_num = step

-            # apply network normalizer if we are using it
-            if self.network is not None and self.network.is_normalizing:
+            # apply network normalizer if we are using it, not on regularization steps
+            if self.network is not None and self.network.is_normalizing and not is_reg_step:
                 self.network.apply_stored_normalizer()

+            # if the batch is a DataLoaderBatchDTO, then we need to clean it up
+            if isinstance(batch, DataLoaderBatchDTO):
+                batch.cleanup()
+
         self.sample(self.step_num + 1)
         print("")
         self.save()
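Since the DTO owns the batch tensor, it also owns freeing it; the cleanup() call added above is what keeps memory flat across steps. A hedged sketch of the kind of method this implies (the toolkit's real cleanup may release more state, such as cached latents):

```python
# Hedged sketch of a batch cleanup method; the real implementation
# may release additional state.
import gc

import torch


class DataLoaderBatchDTO:
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    def cleanup(self):
        self.tensor = None  # drop the reference to the batch tensor
        gc.collect()        # reclaim host memory promptly
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return freed VRAM to the allocator
```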
@@ -1,76 +0,0 @@
-# ref:
-# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
-import time
-from collections import OrderedDict
-import os
-
-from toolkit.config_modules import SliderConfig
-from toolkit.paths import REPOS_ROOT
-import sys
-
-sys.path.append(REPOS_ROOT)
-sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
-from toolkit.train_tools import get_torch_dtype, apply_noise_offset
-import gc
-
-import torch
-from leco import train_util, model_util
-from leco.prompt_util import PromptEmbedsCache
-from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
-
-
-def flush():
-    torch.cuda.empty_cache()
-    gc.collect()
-
-
-class LoRAHack:
-    def __init__(self, **kwargs):
-        self.type = kwargs.get('type', 'suppression')
-
-
-class TrainLoRAHack(BaseSDTrainProcess):
-    def __init__(self, process_id: int, job, config: OrderedDict):
-        super().__init__(process_id, job, config)
-        self.hack_config = LoRAHack(**self.get_conf('hack', {}))
-
-    def hook_before_train_loop(self):
-        # we don't need text encoder so move it to cpu
-        self.sd.text_encoder.to("cpu")
-        flush()
-        # end hook_before_train_loop
-
-        if self.hack_config.type == 'suppression':
-            # set all params to self.current_suppression
-            params = self.network.parameters()
-            for param in params:
-                # get random noise for each param
-                noise = torch.randn_like(param) - 0.5
-                # apply noise to param
-                param.data = noise * 0.001
-
-    def supress_loop(self):
-        dtype = get_torch_dtype(self.train_config.dtype)
-
-        loss_dict = OrderedDict(
-            {'sup': 0.0}
-        )
-        # increase noise
-        for param in self.network.parameters():
-            # get random noise for each param
-            noise = torch.randn_like(param) - 0.5
-            # apply noise to param
-            param.data = param.data + noise * 0.001
-
-        return loss_dict
-
-    def hook_train_loop(self, batch):
-        if self.hack_config.type == 'suppression':
-            return self.supress_loop()
-        else:
-            raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
-    # end hook_train_loop
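A side note on the deleted suppression hack: torch.randn_like draws zero-mean Gaussian noise, so subtracting 0.5 shifts its mean to -0.5 rather than centering it. If zero-centered noise was the intent, torch.rand_like (uniform on [0, 1)) is the call that pairs with a -0.5 shift:

```python
# Illustration of the distinction (not part of the commit):
import torch

p = torch.zeros(1000, 1000)
gaussian_shifted = torch.randn_like(p) - 0.5  # mean ~ -0.5, std ~ 1.0
uniform_centered = torch.rand_like(p) - 0.5   # mean ~ 0.0, range [-0.5, 0.5)
print(gaussian_shifted.mean().item(), uniform_centered.mean().item())
```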
@@ -7,7 +7,6 @@ from .TrainVAEProcess import TrainVAEProcess
 from .BaseMergeProcess import BaseMergeProcess
 from .TrainSliderProcess import TrainSliderProcess
 from .TrainSliderProcessOld import TrainSliderProcessOld
-from .TrainLoRAHack import TrainLoRAHack
 from .TrainSDRescaleProcess import TrainSDRescaleProcess
 from .ModRescaleLoraProcess import ModRescaleLoraProcess
 from .GenerateProcess import GenerateProcess