Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added built-in plugins and made one for image reference sliders. WIP
@@ -0,0 +1,202 @@
import copy
import random
from collections import OrderedDict
import os
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader
from toolkit.data_loader import PairedImageDataset
from toolkit.prompt_utils import concat_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class ReferenceSliderConfig:
    def __init__(self, **kwargs):
        self.slider_pair_folder: str = kwargs.get('slider_pair_folder', None)
        self.resolutions: List[int] = kwargs.get('resolutions', [512])
        self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
        # the default is a string, so the annotation is str, not int
        self.target_class: str = kwargs.get('target_class', '')


class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        self.prompt_txt_list = None
        self.step_num = 0
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.slider_config = ReferenceSliderConfig(**self.get_conf('slider', {}))

    def load_datasets(self):
        if self.data_loader is None:
            print("Loading datasets")
            datasets = []
            for res in self.slider_config.resolutions:
                print(f" - Dataset: {self.slider_config.slider_pair_folder}")
                config = {
                    'path': self.slider_config.slider_pair_folder,
                    'size': res,
                    'default_prompt': self.slider_config.target_class
                }
                image_dataset = PairedImageDataset(config)
                datasets.append(image_dataset)

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.train_config.batch_size,
                shuffle=True,
                num_workers=2
            )

    def before_model_load(self):
        pass

    def hook_before_train_loop(self):
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)
        self.load_datasets()

    def hook_train_loop(self, batch):
        with torch.no_grad():
            imgs, prompts = batch
            dtype = get_torch_dtype(self.train_config.dtype)
            imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)

            # split batched images in half so left is negative and right is positive
            negative_images, positive_images = torch.chunk(imgs, 2, dim=3)

            height = positive_images.shape[2]
            width = positive_images.shape[3]
            batch_size = positive_images.shape[0]

            # encode the images
            positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
            positive_latents = positive_latents * 0.18215
            negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
            negative_latents = negative_latents * 0.18215

            embedding_list = []
            negative_embedding_list = []
            # embed the prompts
            for prompt in prompts:
                embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
                embedding_list.append(embedding)
                # just empty for now
                # todo cache this?
                negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
                negative_embedding_list.append(negative_embed)

            conditional_embeds = concat_prompt_embeds(embedding_list)
            unconditional_embeds = concat_prompt_embeds(negative_embedding_list)

        if self.train_config.gradient_checkpointing:
            # may get disabled elsewhere
            self.sd.unet.enable_gradient_checkpointing()

        noise_scheduler = self.sd.noise_scheduler
        optimizer = self.optimizer
        lr_scheduler = self.lr_scheduler
        loss_function = torch.nn.MSELoss()

        def get_noise_pred(neg, pos, gs, cts, dn):
            return self.sd.predict_noise(
                latents=dn,
                text_embeddings=train_tools.concat_prompt_embeddings(
                    neg,  # negative prompt
                    pos,  # positive prompt
                    self.train_config.batch_size,
                ),
                timestep=cts,
                guidance_scale=gs,
            )

        with torch.no_grad():
            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
            timesteps = timesteps.long()

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=height,
                pixel_width=width,
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset,
            ).to(self.device_torch, dtype=dtype)

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps)
            noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps)

        flush()

        self.optimizer.zero_grad()
        with self.network:
            assert self.network.is_active
            loss_list = []
            for noisy_latents, network_multiplier in zip(
                    [noisy_positive_latents, noisy_negative_latents],
                    [1.0, -1.0],
            ):
                # do positive first
                self.network.multiplier = network_multiplier

                noise_pred = get_noise_pred(
                    unconditional_embeds,
                    conditional_embeds,
                    1,
                    timesteps,
                    noisy_latents
                )

                if self.sd.is_v2:  # check if v-pred; don't want to track it down right now
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                # todo add snr gamma here

                loss = loss.mean()
                # back propagate loss to free ram
                loss.backward()
                loss_list.append(loss.item())

                flush()

            # apply gradients
            optimizer.step()
            lr_scheduler.step()

            loss_float = sum(loss_list) / len(loss_list)

        # reset network
        self.network.multiplier = 1.0

        loss_dict = OrderedDict(
            {'loss': loss_float},
        )
        return loss_dict
    # end hook_train_loop
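For orientation: each training sample is a single side-by-side image whose left half is the negative example and right half is the positive example, and the LoRA network is then trained with multiplier +1.0 toward the positive latents and -1.0 toward the negative latents. A minimal, self-contained sketch of the split (shapes only, not part of the commit):

import torch

# a batch of 2 side-by-side pairs: 3 channels, 512 px tall, 1024 px wide
imgs = torch.randn(2, 3, 512, 1024)
negative_images, positive_images = torch.chunk(imgs, 2, dim=3)  # split along width
print(negative_images.shape, positive_images.shape)  # both torch.Size([2, 3, 512, 512])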
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ImageReferenceSliderTrainer(Extension):
    # uid must be unique; it is how the extension is identified
    uid = "image_reference_slider_trainer"

    # name is the name of the extension for printing
    name = "Image Reference Slider Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed, and return it
        from .ImageReferenceSliderTrainerProcess import ImageReferenceSliderTrainerProcess
        return ImageReferenceSliderTrainerProcess


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ImageReferenceSliderTrainer
]
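To show how the extension surfaces at runtime: a short usage sketch, assuming this package sits in one of the folders scanned by get_all_extensions (shown in the toolkit/extension.py hunk further below):

from toolkit.extension import get_all_extensions

for ext in get_all_extensions():
    if ext.uid == "image_reference_slider_trainer":
        process_class = ext.get_process()  # the lazy import happens here
        print(f"found {ext.name} -> {process_class.__name__}")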
@@ -1,6 +1,9 @@
 import glob
 from collections import OrderedDict
 import os
+from typing import Union
+
+from torch.utils.data import DataLoader
 
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -54,6 +57,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
+        self.data_loader: Union[DataLoader, None] = None
 
         self.sd = StableDiffusion(
             device=self.device,
@@ -193,7 +197,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def hook_before_train_loop(self):
         pass
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch=None):
         # return loss
         return 0.0
 
@@ -358,12 +362,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
             iterable=range(0, self.train_config.steps),
         )
 
+        if self.data_loader is not None:
+            dataloader = self.data_loader
+            dataloader_iterator = iter(dataloader)
+        else:
+            dataloader = None
+            dataloader_iterator = None
+
         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
+            # todo handle dataloader here maybe, not sure
+            if dataloader is not None:
+                try:
+                    batch = next(dataloader_iterator)
+                except StopIteration:
+                    # hit the end of an epoch, reset
+                    # todo, should we do something else here? like blow up balloons?
+                    dataloader_iterator = iter(dataloader)
+                    batch = next(dataloader_iterator)
+            else:
+                batch = None
+
             ### HOOK ###
-            loss_dict = self.hook_train_loop()
+            loss_dict = self.hook_train_loop(batch)
             flush()
 
             if self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):
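The try/except StopIteration above simply restarts the dataloader when an epoch ends, so the fixed-step training loop never runs dry. The same idiom as a reusable generator, a sketch that is not part of this commit and assumes the dataloader yields at least one batch:

def cycle(dataloader):
    # yield batches forever, restarting the iterator at each epoch boundary
    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(dataloader)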
@@ -29,11 +29,11 @@ class BaseTrainProcess(BaseProcess):
         super().__init__(process_id, job, config)
         self.progress_bar = None
         self.writer = None
-        self.training_folder = self.get_conf('training_folder', self.job.training_folder)
-        self.save_root = os.path.join(self.training_folder, self.job.name)
+        self.training_folder = self.get_conf('training_folder', self.job.training_folder if hasattr(self.job, 'training_folder') else None)
+        self.save_root = os.path.join(self.training_folder, self.name)
         self.step = 0
         self.first_step = 0
-        self.log_dir = self.get_conf('log_dir', self.job.log_dir)
+        self.log_dir = self.get_conf('log_dir', self.job.log_dir if hasattr(self.job, 'log_dir') else None)
         self.setup_tensorboard()
         self.save_training_config()
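The hasattr guards let job types that define no training_folder or log_dir (extension jobs, presumably) construct this class without an AttributeError. A behaviorally equivalent, slightly tighter form using getattr, shown only as an alternative sketch:

self.training_folder = self.get_conf('training_folder', getattr(self.job, 'training_folder', None))
self.log_dir = self.get_conf('log_dir', getattr(self.job, 'log_dir', None))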
@@ -62,7 +62,7 @@ class BaseTrainProcess(BaseProcess):
     def save_training_config(self):
         timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
-        os.makedirs(self.training_folder, exist_ok=True)
-        save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml')
+        os.makedirs(self.save_root, exist_ok=True)
+        save_dif = os.path.join(self.save_root, f'process_config_{timestamp}.yaml')
         with open(save_dif, 'w') as f:
             yaml.dump(self.raw_process_config, f)
@@ -68,7 +68,7 @@ class TrainLoRAHack(BaseSDTrainProcess):
 
         return loss_dict
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         if self.hack_config.type == 'suppression':
             return self.supress_loop()
         else:
@@ -210,7 +210,7 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
         flush()
     # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         loss_function = torch.nn.MSELoss()
@@ -173,7 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
         flush()
     # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # get a random pair
@@ -221,7 +221,7 @@ class TrainSliderProcessOld(BaseSDTrainProcess):
         flush()
     # end hook_before_train_loop
 
-    def hook_train_loop(self):
+    def hook_train_loop(self, batch):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # get a random pair
@@ -13,3 +13,4 @@ from .ModRescaleLoraProcess import ModRescaleLoraProcess
 from .GenerateProcess import GenerateProcess
 from .BaseExtensionProcess import BaseExtensionProcess
 from .TrainESRGANProcess import TrainESRGANProcess
+from .BaseSDTrainProcess import BaseSDTrainProcess
@@ -140,3 +140,65 @@ class AugmentedImageDataset(ImageDataset):

        # return both # return image as 0 - 1 tensor
        return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)


class PairedImageDataset(Dataset):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.size = self.get_config('size', 512)
        self.path = self.get_config('path', required=True)
        self.default_prompt = self.get_config('default_prompt', '')
        self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                          file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
        print(f" - Found {len(self.file_list)} images")

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])

    def __len__(self):
        return len(self.file_list)

    def get_config(self, key, default=None, required=False):
        if key in self.config:
            value = self.config[key]
            return value
        elif required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        else:
            return default

    def __getitem__(self, index):
        img_path = self.file_list[index]
        img = exif_transpose(Image.open(img_path)).convert('RGB')

        # see if a prompt file exists next to the image
        path_no_ext = os.path.splitext(img_path)[0]
        prompt_path = path_no_ext + '.txt'
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
                # replace newlines with commas, covering all operating systems
                prompt = prompt.replace('\n', ', ')
                prompt = prompt.replace('\r', ', ')
                prompt_split = prompt.split(',')
                # remove empty strings
                prompt_split = [p.strip() for p in prompt_split if p.strip()]
                # join back together
                prompt = ', '.join(prompt_split)
        else:
            prompt = self.default_prompt

        height = self.size
        # determine width to keep the aspect ratio
        width = int(img.size[0] * height / img.size[1])

        # downscale the source image first
        img = img.resize((width, height), Image.BICUBIC)
        img = self.transform(img)

        return img, prompt
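A usage sketch for the new dataset, assuming a hypothetical ./pairs folder of side-by-side images with optional same-named .txt caption files:

from toolkit.data_loader import PairedImageDataset

dataset = PairedImageDataset({
    'path': './pairs',            # hypothetical folder of paired images
    'size': 512,
    'default_prompt': 'person',   # used when an image has no .txt caption
})
img, prompt = dataset[0]
print(img.shape, prompt)  # e.g. torch.Size([3, 512, 1024]) for a 2:1 pair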
@@ -25,25 +25,26 @@ class Extension(object):
 
 
 def get_all_extensions() -> List[Extension]:
-    # Get the path of the "extensions" directory
-    extensions_dir = os.path.join(TOOLKIT_ROOT, "extensions")
+    extension_folders = ['extensions', 'extensions_built_in']
 
     # This will hold the classes from all extension modules
     all_extension_classes: List[Extension] = []
 
-    # Iterate over all directories (i.e., packages) in the "extensions" directory
-    for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
-        try:
-            # Import the module
-            module = importlib.import_module(f"extensions.{name}")
-            # Get the value of the AI_TOOLKIT_EXTENSIONS variable
-            extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
-            # Check if the value is a list
-            if isinstance(extensions, list):
-                # Iterate over the list and add the classes to the main list
-                all_extension_classes.extend(extensions)
-        except ImportError as e:
-            print(f"Failed to import the {name} module. Error: {str(e)}")
+    for sub_dir in extension_folders:
+        extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
+        for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
+            try:
+                # Import the module
+                module = importlib.import_module(f"{sub_dir}.{name}")
+                # Get the value of the AI_TOOLKIT_EXTENSIONS variable
+                extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
+                # Check if the value is a list
+                if isinstance(extensions, list):
+                    # Iterate over the list and add the classes to the main list
+                    all_extension_classes.extend(extensions)
+            except ImportError as e:
+                print(f"Failed to import the {name} module. Error: {str(e)}")
 
     return all_extension_classes
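For orientation, the on-disk layout this loader scans would look roughly like the tree below; the package folder name is hypothetical, while extensions_built_in, the __init__.py defining AI_TOOLKIT_EXTENSIONS, and the process module name come from this commit:

extensions_built_in/
    image_reference_slider/                   # hypothetical package name
        __init__.py                           # defines AI_TOOLKIT_EXTENSIONS
        ImageReferenceSliderTrainerProcess.py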