Added built-in plugins and made one for image reference sliders. WIP

This commit is contained in:
Jaret Burkett
2023-08-10 16:02:44 -06:00
parent df48f0a843
commit 1a7e346b41
12 changed files with 338 additions and 26 deletions

View File

@@ -0,0 +1,202 @@
import copy
import random
from collections import OrderedDict
import os
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader
from toolkit.data_loader import PairedImageDataset
from toolkit.prompt_utils import concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class ReferenceSliderConfig:
    def __init__(self, **kwargs):
        self.slider_pair_folder: str = kwargs.get('slider_pair_folder', None)
        self.resolutions: List[int] = kwargs.get('resolutions', [512])
        self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
        # target_class is used as the dataset's default prompt, so it is a string
        self.target_class: str = kwargs.get('target_class', '')
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        self.prompt_txt_list = None
        self.step_num = 0
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.slider_config = ReferenceSliderConfig(**self.get_conf('slider', {}))

    def load_datasets(self):
        if self.data_loader is None:
            print(f"Loading datasets")
            datasets = []
            for res in self.slider_config.resolutions:
                print(f" - Dataset: {self.slider_config.slider_pair_folder}")
                config = {
                    'path': self.slider_config.slider_pair_folder,
                    'size': res,
                    'default_prompt': self.slider_config.target_class
                }
                image_dataset = PairedImageDataset(config)
                datasets.append(image_dataset)

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.train_config.batch_size,
                shuffle=True,
                num_workers=2
            )

    def before_model_load(self):
        pass

    def hook_before_train_loop(self):
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)
        self.load_datasets()
        pass
    def hook_train_loop(self, batch):
        with torch.no_grad():
            imgs, prompts = batch
            dtype = get_torch_dtype(self.train_config.dtype)
            imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
            # split batched images in half so left is negative and right is positive
            negative_images, positive_images = torch.chunk(imgs, 2, dim=3)

            height = positive_images.shape[2]
            width = positive_images.shape[3]
            batch_size = positive_images.shape[0]

            # encode the images
            positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
            positive_latents = positive_latents * 0.18215
            negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
            negative_latents = negative_latents * 0.18215

            embedding_list = []
            negative_embedding_list = []
            # embed the prompts
            for prompt in prompts:
                embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
                embedding_list.append(embedding)

                # just empty for now
                # todo cache this?
                negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
                negative_embedding_list.append(negative_embed)

            conditional_embeds = concat_prompt_embeds(embedding_list)
            unconditional_embeds = concat_prompt_embeds(negative_embedding_list)

        if self.train_config.gradient_checkpointing:
            # may get disabled elsewhere
            self.sd.unet.enable_gradient_checkpointing()

        noise_scheduler = self.sd.noise_scheduler
        optimizer = self.optimizer
        lr_scheduler = self.lr_scheduler
        loss_function = torch.nn.MSELoss()

        def get_noise_pred(neg, pos, gs, cts, dn):
            return self.sd.predict_noise(
                latents=dn,
                text_embeddings=train_tools.concat_prompt_embeddings(
                    neg,  # negative prompt
                    pos,  # positive prompt
                    self.train_config.batch_size,
                ),
                timestep=cts,
                guidance_scale=gs,
            )

        with torch.no_grad():
            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
            timesteps = timesteps.long()

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=height,
                pixel_width=width,
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset,
            ).to(self.device_torch, dtype=dtype)

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps)
            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps)

        flush()

        self.optimizer.zero_grad()
        with self.network:
            assert self.network.is_active
            loss_list = []
            for noisy_latents, network_multiplier in zip(
                    [noisy_positive_latents, noisy_negative_latents],
                    [1.0, -1.0],
            ):
                # do positive first
                self.network.multiplier = network_multiplier

                noise_pred = get_noise_pred(
                    unconditional_embeds,
                    conditional_embeds,
                    1,
                    timesteps,
                    noisy_latents
                )

                if self.sd.is_v2:  # check is vpred, don't want to track it down right now
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])
                # todo add snr gamma here
                loss = loss.mean()

                # back propagate loss to free ram
                loss.backward()
                loss_list.append(loss.item())
                flush()

        # apply gradients
        optimizer.step()
        lr_scheduler.step()

        loss_float = sum(loss_list) / len(loss_list)

        # reset network
        self.network.multiplier = 1.0

        loss_dict = OrderedDict(
            {'loss': loss_float},
        )

        return loss_dict
    # end hook_train_loop
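The process reads its slider settings from the `slider` section of the process config via `self.get_conf('slider', {})` and passes them into `ReferenceSliderConfig`. The job config itself is not part of this commit, so the snippet below is only a rough sketch of that mapping using the keys defined above; the folder path and target class value are placeholder assumptions.

# Hedged sketch of the 'slider' settings this process expects, based only on the
# ReferenceSliderConfig class above. Path and target_class values are placeholders.
slider_conf = {
    'slider_pair_folder': '/path/to/paired_images',  # side-by-side pairs: left = negative, right = positive
    'resolutions': [512],        # one PairedImageDataset is built per resolution
    'batch_full_slide': True,    # defined in the config; not referenced elsewhere in this diff
    'target_class': 'person',    # becomes the dataset's default_prompt
}

slider_config = ReferenceSliderConfig(**slider_conf)
assert slider_config.resolutions == [512]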

View File

@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ImageReferenceSliderTrainer(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "image_reference_slider_trainer"

    # name is the name of the extension for printing
    name = "Image Reference Slider Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess
        return ImageReferenceSliderTrainerProcess


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ImageReferenceSliderTrainer
]

View File

@@ -1,6 +1,9 @@
 import glob
 from collections import OrderedDict
 import os
+from typing import Union
+from torch.utils.data import DataLoader
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -54,6 +57,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
+        self.data_loader: Union[DataLoader, None] = None
 
         self.sd = StableDiffusion(
             device=self.device,
@@ -193,7 +197,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def hook_before_train_loop(self):
         pass
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch=None):
         # return loss
         return 0.0
@@ -358,12 +362,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
             iterable=range(0, self.train_config.steps),
         )
 
+        if self.data_loader is not None:
+            dataloader = self.data_loader
+            dataloader_iterator = iter(dataloader)
+        else:
+            dataloader = None
+            dataloader_iterator = None
+
         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
-            # todo handle dataloader here maybe, not sure
+            if dataloader is not None:
+                try:
+                    batch = next(dataloader_iterator)
+                except StopIteration:
+                    # hit the end of an epoch, reset
+                    # todo, should we do something else here? like blow up balloons?
+                    dataloader_iterator = iter(dataloader)
+                    batch = next(dataloader_iterator)
+            else:
+                batch = None
 
             ### HOOK ###
-            loss_dict = self.hook_train_loop()
+            loss_dict = self.hook_train_loop(batch)
+            flush()
 
             if self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):

View File

@@ -29,11 +29,11 @@ class BaseTrainProcess(BaseProcess):
         super().__init__(process_id, job, config)
         self.progress_bar = None
         self.writer = None
-        self.training_folder = self.get_conf('training_folder', self.job.training_folder)
-        self.save_root = os.path.join(self.training_folder, self.job.name)
+        self.training_folder = self.get_conf('training_folder', self.job.training_folder if hasattr(self.job, 'training_folder') else None)
+        self.save_root = os.path.join(self.training_folder, self.name)
         self.step = 0
         self.first_step = 0
-        self.log_dir = self.get_conf('log_dir', self.job.log_dir)
+        self.log_dir = self.get_conf('log_dir', self.job.log_dir if hasattr(self.job, 'log_dir') else None)
         self.setup_tensorboard()
         self.save_training_config()
@@ -62,7 +62,7 @@ class BaseTrainProcess(BaseProcess):
     def save_training_config(self):
         timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
-        os.makedirs(self.training_folder, exist_ok=True)
-        save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml')
+        os.makedirs(self.save_root, exist_ok=True)
+        save_dif = os.path.join(self.save_root, f'process_config_{timestamp}.yaml')
         with open(save_dif, 'w') as f:
             yaml.dump(self.raw_process_config, f)

View File

@@ -68,7 +68,7 @@ class TrainLoRAHack(BaseSDTrainProcess):
         return loss_dict
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         if self.hack_config.type == 'suppression':
             return self.supress_loop()
         else:

View File

@@ -210,7 +210,7 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
         flush()
     # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
         loss_function = torch.nn.MSELoss()

View File

@@ -173,7 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
         flush()
     # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # get a random pair

View File

@@ -221,7 +221,7 @@ class TrainSliderProcessOld(BaseSDTrainProcess):
         flush()
     # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # get a random pair

View File

@@ -13,3 +13,4 @@ from .ModRescaleLoraProcess import ModRescaleLoraProcess
 from .GenerateProcess import GenerateProcess
 from .BaseExtensionProcess import BaseExtensionProcess
 from .TrainESRGANProcess import TrainESRGANProcess
+from .BaseSDTrainProcess import BaseSDTrainProcess

View File

@@ -140,3 +140,65 @@ class AugmentedImageDataset(ImageDataset):
-        # return both
+        # return image as 0 - 1 tensor
         return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)
+
+
+class PairedImageDataset(Dataset):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.size = self.get_config('size', 512)
+        self.path = self.get_config('path', required=True)
+        self.default_prompt = self.get_config('default_prompt', '')
+        self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
+                          file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
+        print(f" - Found {len(self.file_list)} images")
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
+        ])
+
+    def __len__(self):
+        return len(self.file_list)
+
+    def get_config(self, key, default=None, required=False):
+        if key in self.config:
+            value = self.config[key]
+            return value
+        elif required:
+            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
+        else:
+            return default
+
+    def __getitem__(self, index):
+        img_path = self.file_list[index]
+        img = exif_transpose(Image.open(img_path)).convert('RGB')
+
+        # see if prompt file exists
+        path_no_ext = os.path.splitext(img_path)[0]
+        prompt_path = path_no_ext + '.txt'
+        if os.path.exists(prompt_path):
+            with open(prompt_path, 'r', encoding='utf-8') as f:
+                prompt = f.read()
+                # remove any newlines
+                prompt = prompt.replace('\n', ', ')
+                # remove new lines for all operating systems
+                prompt = prompt.replace('\r', ', ')
+                prompt_split = prompt.split(',')
+                # remove empty strings
+                prompt_split = [p.strip() for p in prompt_split if p.strip()]
+                # join back together
+                prompt = ', '.join(prompt_split)
+        else:
+            prompt = self.default_prompt
+
+        height = self.size
+        # determine width to keep aspect ratio
+        width = int(img.size[0] * height / img.size[1])
+
+        # Downscale the source image first
+        img = img.resize((width, height), Image.BICUBIC)
+
+        img = self.transform(img)
+        return img, prompt
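For context, a minimal sketch of how the trainer in this commit consumes PairedImageDataset: it wraps the dataset in a DataLoader and splits each batch down the middle into negative/positive halves (see hook_train_loop above). The folder path and prompt are placeholders, and the sketch assumes every pair image shares the same aspect ratio so the resized tensors batch cleanly.

import torch
from torch.utils.data import DataLoader
from toolkit.data_loader import PairedImageDataset

# Placeholder folder of side-by-side pair images; a .txt file next to an image
# overrides the default prompt, as in __getitem__ above.
dataset = PairedImageDataset({
    'path': '/path/to/paired_images',
    'size': 512,
    'default_prompt': 'photo of a person',  # placeholder prompt
})
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

imgs, prompts = next(iter(loader))
# left half is the negative image, right half is the positive image
negative_images, positive_images = torch.chunk(imgs, 2, dim=3)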

View File

@@ -25,25 +25,26 @@ class Extension(object):
 def get_all_extensions() -> List[Extension]:
-    # Get the path of the "extensions" directory
-    extensions_dir = os.path.join(TOOLKIT_ROOT, "extensions")
+    extension_folders = ['extensions', 'extensions_built_in']
 
     # This will hold the classes from all extension modules
     all_extension_classes: List[Extension] = []
 
     # Iterate over all directories (i.e., packages) in the "extensions" directory
-    for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
-        try:
-            # Import the module
-            module = importlib.import_module(f"extensions.{name}")
-            # Get the value of the AI_TOOLKIT_EXTENSIONS variable
-            extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
-            # Check if the value is a list
-            if isinstance(extensions, list):
-                # Iterate over the list and add the classes to the main list
-                all_extension_classes.extend(extensions)
-        except ImportError as e:
-            print(f"Failed to import the {name} module. Error: {str(e)}")
+    for sub_dir in extension_folders:
+        extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
+        for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
+            try:
+                # Import the module
+                module = importlib.import_module(f"{sub_dir}.{name}")
+                # Get the value of the AI_TOOLKIT_EXTENSIONS variable
+                extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
+                # Check if the value is a list
+                if isinstance(extensions, list):
+                    # Iterate over the list and add the classes to the main list
+                    all_extension_classes.extend(extensions)
+            except ImportError as e:
+                print(f"Failed to import the {name} module. Error: {str(e)}")
 
     return all_extension_classes
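With this change, discovery works the same for both folders: any package under extensions/ or extensions_built_in/ that exposes an AI_TOOLKIT_EXTENSIONS list (like the ImageReferenceSliderTrainer package above) is picked up. A minimal usage sketch of the loader, assuming it is run from the repo root so TOOLKIT_ROOT resolves:

from toolkit.extension import get_all_extensions

# Each entry is an Extension subclass such as ImageReferenceSliderTrainer;
# uid and name are class attributes, and get_process() lazily imports the
# process class that actually runs the training job.
for ext in get_all_extensions():
    print(f"{ext.uid} -> {ext.name}")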