mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Bug fixes. Added some functionality to help with private extensions
This commit is contained in:
20
extensions_built_in/dataset_tools/DatasetTools.py
Normal file
20
extensions_built_in/dataset_tools/DatasetTools.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
from jobs.process import BaseExtensionProcess
|
||||||
|
|
||||||
|
|
||||||
|
def flush():
    """Free memory that is no longer in use on the host and on the GPU.

    Runs the Python garbage collector first so that objects kept alive only
    by reference cycles are destroyed, then asks PyTorch to release the
    unused memory held by its CUDA caching allocator.
    """
    # collect first: memory still owned by uncollected garbage is not
    # "unused" from the allocator's point of view and would stay cached
    gc.collect()
    torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTools(BaseExtensionProcess):
    """Placeholder process for the dataset tools extension.

    The process is wired into the extension system but has no behaviour of
    its own yet: ``run`` always raises ``NotImplementedError``.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        """Forward ``process_id``, ``job`` and ``config`` to the base process unchanged."""
        super().__init__(process_id, job, config)

    def run(self):
        """Run the base process step, then signal that nothing is implemented.

        Raises:
            NotImplementedError: always, after the base ``run`` has returned.
        """
        super().run()

        raise NotImplementedError("This extension is not yet implemented")
|
||||||
25
extensions_built_in/dataset_tools/__init__.py
Normal file
25
extensions_built_in/dataset_tools/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Extension registration for the built-in Dataset Tools process.
|
||||||
|
from toolkit.extension import Extension
|
||||||
|
|
||||||
|
|
||||||
|
# Extension descriptor for the Dataset Tools process
|
||||||
|
class DatasetToolsExtension(Extension):
    """Extension descriptor that exposes the ``DatasetTools`` process."""

    # uid must be unique, it is how the extension is identified
    uid = "dataset_tools"

    # name is the name of the extension for printing
    name = "Dataset Tools"

    @classmethod
    def get_process(cls):
        """Return the ``DatasetTools`` process class.

        The import is done inside the method so the process module (and
        whatever it imports) is only loaded when the process is requested,
        which keeps it from slowing down the rest of the program.
        """
        from .DatasetTools import DatasetTools
        return DatasetTools
|
||||||
|
|
||||||
|
|
||||||
|
# Extension classes exported by this package.
# NOTE(review): presumably read by the toolkit's extension loader - confirm.
AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    DatasetToolsExtension,
]
|
||||||
@@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if self.trigger_word is not None:
|
if self.trigger_word is not None:
|
||||||
# just so auto1111 will pick it up
|
# just so auto1111 will pick it up
|
||||||
o_dict['ss_tag_frequency'] = {
|
o_dict['ss_tag_frequency'] = {
|
||||||
'actfig': {
|
[self.trigger_word ]: {
|
||||||
'actfig': 1
|
[self.trigger_word ]: 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -827,14 +827,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
else: # no network, embedding or adapter
|
else: # no network, embedding or adapter
|
||||||
# set the device state preset before getting params
|
# set the device state preset before getting params
|
||||||
self.sd.set_device_state(self.train_device_state_preset)
|
self.sd.set_device_state(self.train_device_state_preset)
|
||||||
# will only return savable weights and ones with grad
|
|
||||||
params = self.sd.prepare_optimizer_params(
|
params = self.get_params()
|
||||||
unet=self.train_config.train_unet,
|
if not params:
|
||||||
text_encoder=self.train_config.train_text_encoder,
|
# will only return savable weights and ones with grad
|
||||||
text_encoder_lr=self.train_config.lr,
|
params = self.sd.prepare_optimizer_params(
|
||||||
unet_lr=self.train_config.lr,
|
unet=self.train_config.train_unet,
|
||||||
default_lr=self.train_config.lr
|
text_encoder=self.train_config.train_text_encoder,
|
||||||
)
|
text_encoder_lr=self.train_config.lr,
|
||||||
|
unet_lr=self.train_config.lr,
|
||||||
|
default_lr=self.train_config.lr
|
||||||
|
)
|
||||||
# we may be using it for prompt injections
|
# we may be using it for prompt injections
|
||||||
if self.adapter_config is not None:
|
if self.adapter_config is not None:
|
||||||
self.setup_adapter()
|
self.setup_adapter()
|
||||||
|
|||||||
@@ -478,6 +478,7 @@ def get_dataloader_from_datasets(
|
|||||||
|
|
||||||
datasets = []
|
datasets = []
|
||||||
has_buckets = False
|
has_buckets = False
|
||||||
|
is_caching_latents = False
|
||||||
|
|
||||||
dataset_config_list = []
|
dataset_config_list = []
|
||||||
# preprocess them all
|
# preprocess them all
|
||||||
@@ -497,6 +498,8 @@ def get_dataloader_from_datasets(
|
|||||||
datasets.append(dataset)
|
datasets.append(dataset)
|
||||||
if config.buckets:
|
if config.buckets:
|
||||||
has_buckets = True
|
has_buckets = True
|
||||||
|
if config.cache_latents or config.cache_latents_to_disk:
|
||||||
|
is_caching_latents = True
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"invalid dataset type: {config.type}")
|
raise ValueError(f"invalid dataset type: {config.type}")
|
||||||
|
|
||||||
@@ -512,6 +515,9 @@ def get_dataloader_from_datasets(
|
|||||||
)
|
)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
# check if is caching latents
|
||||||
|
|
||||||
|
|
||||||
if has_buckets:
|
if has_buckets:
|
||||||
# make sure they all have buckets
|
# make sure they all have buckets
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
|
|||||||
@@ -84,8 +84,22 @@ class DataLoaderBatchDTO:
|
|||||||
if is_latents_cached:
|
if is_latents_cached:
|
||||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||||
self.control_tensor: Union[torch.Tensor, None] = None
|
self.control_tensor: Union[torch.Tensor, None] = None
|
||||||
if self.file_items[0].control_tensor is not None:
|
# if self.file_items[0].control_tensor is not None:
|
||||||
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
|
# if any have a control tensor, we concatenate them
|
||||||
|
if any([x.control_tensor is not None for x in self.file_items]):
|
||||||
|
# find one to use as a base
|
||||||
|
base_control_tensor = None
|
||||||
|
for x in self.file_items:
|
||||||
|
if x.control_tensor is not None:
|
||||||
|
base_control_tensor = x.control_tensor
|
||||||
|
break
|
||||||
|
control_tensors = []
|
||||||
|
for x in self.file_items:
|
||||||
|
if x.control_tensor is None:
|
||||||
|
control_tensors.append(torch.zeros_like(base_control_tensor))
|
||||||
|
else:
|
||||||
|
control_tensors.append(x.control_tensor)
|
||||||
|
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -317,6 +317,7 @@ class ImageProcessingDTOMixin:
|
|||||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||||
# crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
|
# crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
|
||||||
if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
|
if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
|
||||||
|
# todo look into this. This still happens sometimes
|
||||||
print('size mismatch')
|
print('size mismatch')
|
||||||
img = img.crop((
|
img = img.crop((
|
||||||
self.crop_x,
|
self.crop_x,
|
||||||
|
|||||||
@@ -5,14 +5,18 @@ import itertools
|
|||||||
|
|
||||||
|
|
||||||
class LosslessLatentDecoder(nn.Module):
|
class LosslessLatentDecoder(nn.Module):
|
||||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||||
super(LosslessLatentDecoder, self).__init__()
|
super(LosslessLatentDecoder, self).__init__()
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.latent_depth = latent_depth
|
self.latent_depth = latent_depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = int(in_channels // (latent_depth * latent_depth))
|
self.out_channels = int(in_channels // (latent_depth * latent_depth))
|
||||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||||
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||||
|
if trainable:
|
||||||
|
self.kernel = nn.Parameter(numpy_kernel)
|
||||||
|
else:
|
||||||
|
self.kernel = numpy_kernel
|
||||||
|
|
||||||
def build_kernel(self, in_channels, latent_depth):
|
def build_kernel(self, in_channels, latent_depth):
|
||||||
# my old code from tensorflow.
|
# my old code from tensorflow.
|
||||||
@@ -44,14 +48,18 @@ class LosslessLatentDecoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LosslessLatentEncoder(nn.Module):
|
class LosslessLatentEncoder(nn.Module):
|
||||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||||
super(LosslessLatentEncoder, self).__init__()
|
super(LosslessLatentEncoder, self).__init__()
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.latent_depth = latent_depth
|
self.latent_depth = latent_depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = int(in_channels * (latent_depth * latent_depth))
|
self.out_channels = int(in_channels * (latent_depth * latent_depth))
|
||||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||||
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||||
|
if trainable:
|
||||||
|
self.kernel = nn.Parameter(numpy_kernel)
|
||||||
|
else:
|
||||||
|
self.kernel = numpy_kernel
|
||||||
|
|
||||||
|
|
||||||
def build_kernel(self, in_channels, latent_depth):
|
def build_kernel(self, in_channels, latent_depth):
|
||||||
@@ -82,13 +90,13 @@ class LosslessLatentEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LosslessLatentVAE(nn.Module):
|
class LosslessLatentVAE(nn.Module):
|
||||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||||
super(LosslessLatentVAE, self).__init__()
|
super(LosslessLatentVAE, self).__init__()
|
||||||
self.latent_depth = latent_depth
|
self.latent_depth = latent_depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
|
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
|
||||||
encoder_out_channels = self.encoder.out_channels
|
encoder_out_channels = self.encoder.out_channels
|
||||||
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
|
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
latent = self.latent_encoder(x)
|
latent = self.latent_encoder(x)
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ def flush():
|
|||||||
|
|
||||||
|
|
||||||
UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
|
UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
|
||||||
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||||
|
|
||||||
# if is type checking
|
# if is type checking
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
@@ -471,6 +471,7 @@ class StableDiffusion:
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
noise_offset=0.0,
|
noise_offset=0.0,
|
||||||
):
|
):
|
||||||
|
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||||
if height is None and pixel_height is None:
|
if height is None and pixel_height is None:
|
||||||
raise ValueError("height or pixel_height must be specified")
|
raise ValueError("height or pixel_height must be specified")
|
||||||
if width is None and pixel_width is None:
|
if width is None and pixel_width is None:
|
||||||
@@ -493,6 +494,7 @@ class StableDiffusion:
|
|||||||
return noise
|
return noise
|
||||||
|
|
||||||
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
||||||
|
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
bs, ch, h, w = list(latents.shape)
|
bs, ch, h, w = list(latents.shape)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user