mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Bug fixes. Added some functionality to help with private extensions
20  extensions_built_in/dataset_tools/DatasetTools.py  Normal file
@@ -0,0 +1,20 @@
+from collections import OrderedDict
+import gc
+import torch
+from jobs.process import BaseExtensionProcess
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class DatasetTools(BaseExtensionProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+
+    def run(self):
+        super().run()
+
+        raise NotImplementedError("This extension is not yet implemented")
25  extensions_built_in/dataset_tools/__init__.py  Normal file
@@ -0,0 +1,25 @@
+# This is an example extension for custom training. It is great for experimenting with new ideas.
+from toolkit.extension import Extension
+
+
+# This is for generic training (LoRA, Dreambooth, FineTuning)
+class DatasetToolsExtension(Extension):
+    # uid must be unique, it is how the extension is identified
+    uid = "dataset_tools"
+
+    # name is the name of the extension for printing
+    name = "Dataset Tools"
+
+    # This is where your process class is loaded
+    # keep your imports in here so they don't slow down the rest of the program
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed and return it
+        from .DatasetTools import DatasetTools
+        return DatasetTools
+
+
+AI_TOOLKIT_EXTENSIONS = [
+    # you can put a list of extensions here
+    DatasetToolsExtension,
+]
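For comparison, a private extension outside of extensions_built_in can follow the exact same pattern. A minimal sketch, with hypothetical names and a hypothetical file location:

# hypothetical file: extensions/my_private_tool/__init__.py
from toolkit.extension import Extension


class MyPrivateToolExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "my_private_tool"

    # name is the name of the extension for printing
    name = "My Private Tool"

    @classmethod
    def get_process(cls):
        # deferred import so the process class only loads when needed
        from .MyPrivateTool import MyPrivateTool
        return MyPrivateTool


AI_TOOLKIT_EXTENSIONS = [
    MyPrivateToolExtension,
]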
@@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.trigger_word is not None:
             # just so auto1111 will pick it up
             o_dict['ss_tag_frequency'] = {
-                'actfig': {
-                    'actfig': 1
+                self.trigger_word: {
+                    self.trigger_word: 1
                 }
             }
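The change above replaces a hardcoded 'actfig' key with the configured trigger word, so tools that read kohya-style metadata pick the trigger word up as a tag. A rough sketch of the resulting metadata entry, assuming a trigger word of "sks" and assuming the dict is JSON-encoded into the safetensors header on save:

import json

trigger_word = "sks"
o_dict = {}
o_dict['ss_tag_frequency'] = {
    trigger_word: {
        trigger_word: 1
    }
}
# safetensors metadata values are strings, so the dict would be serialized
print(json.dumps(o_dict['ss_tag_frequency']))  # {"sks": {"sks": 1}}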
@@ -827,14 +827,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:  # no network, embedding or adapter
             # set the device state preset before getting params
             self.sd.set_device_state(self.train_device_state_preset)
-            # will only return savable weights and ones with grad
-            params = self.sd.prepare_optimizer_params(
-                unet=self.train_config.train_unet,
-                text_encoder=self.train_config.train_text_encoder,
-                text_encoder_lr=self.train_config.lr,
-                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
-            )
+            params = self.get_params()
+            if not params:
+                # will only return savable weights and ones with grad
+                params = self.sd.prepare_optimizer_params(
+                    unet=self.train_config.train_unet,
+                    text_encoder=self.train_config.train_text_encoder,
+                    text_encoder_lr=self.train_config.lr,
+                    unet_lr=self.train_config.lr,
+                    default_lr=self.train_config.lr
+                )
         # we may be using it for prompt injections
         if self.adapter_config is not None:
             self.setup_adapter()
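With this change the process's own get_params() is consulted first, and prepare_optimizer_params() is only the fallback, which gives private extensions a hook to supply custom parameter groups. A minimal sketch of such an override; the subclass and its extra module are hypothetical:

class MyPrivateTrainProcess(BaseSDTrainProcess):
    def get_params(self):
        # return optimizer param groups, or None/[] to keep the default path
        if getattr(self, 'my_extra_module', None) is None:
            return None
        return [
            {'params': self.my_extra_module.parameters(), 'lr': self.train_config.lr},
        ]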
@@ -478,6 +478,7 @@ def get_dataloader_from_datasets(
 
     datasets = []
     has_buckets = False
+    is_caching_latents = False
 
     dataset_config_list = []
     # preprocess them all
@@ -497,6 +498,8 @@ def get_dataloader_from_datasets(
             datasets.append(dataset)
             if config.buckets:
                 has_buckets = True
+            if config.cache_latents or config.cache_latents_to_disk:
+                is_caching_latents = True
         else:
             raise ValueError(f"invalid dataset type: {config.type}")
 
@@ -512,6 +515,9 @@ def get_dataloader_from_datasets(
         )
         return batch
 
+    # check if is caching latents
 
     if has_buckets:
         # make sure they all have buckets
         for dataset in datasets:
@@ -84,8 +84,22 @@ class DataLoaderBatchDTO:
             if is_latents_cached:
                 self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
             self.control_tensor: Union[torch.Tensor, None] = None
-            if self.file_items[0].control_tensor is not None:
-                self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
+            # if self.file_items[0].control_tensor is not None:
+            # if any have a control tensor, we concatenate them
+            if any([x.control_tensor is not None for x in self.file_items]):
+                # find one to use as a base
+                base_control_tensor = None
+                for x in self.file_items:
+                    if x.control_tensor is not None:
+                        base_control_tensor = x.control_tensor
+                        break
+                control_tensors = []
+                for x in self.file_items:
+                    if x.control_tensor is None:
+                        control_tensors.append(torch.zeros_like(base_control_tensor))
+                    else:
+                        control_tensors.append(x.control_tensor)
+                self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
         except Exception as e:
             print(e)
             raise e
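The new branch zero-fills missing control tensors so the batch can still be concatenated into a single tensor. A quick standalone check of that pattern, with illustrative shapes:

import torch

# three file items, one without a control tensor
items = [torch.rand(3, 64, 64), None, torch.rand(3, 64, 64)]
base = next(t for t in items if t is not None)
padded = [t if t is not None else torch.zeros_like(base) for t in items]
batch = torch.cat([t.unsqueeze(0) for t in padded])
print(batch.shape)  # torch.Size([3, 3, 64, 64])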
@@ -317,6 +317,7 @@ class ImageProcessingDTOMixin:
                 img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
             # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
             if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
+                # todo look into this. This still happens sometimes
                 print('size mismatch')
             img = img.crop((
                 self.crop_x,
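For context on the bounds check above: PIL's Image.crop takes a (left, upper, right, lower) box and silently pads with black when the box extends past the image instead of raising, so the mismatch would otherwise go unnoticed. A small illustration:

from PIL import Image

img = Image.new('RGB', (100, 100))
# box extends 20px past the right edge; PIL pads rather than failing
out = img.crop((50, 0, 120, 100))
print(out.size)  # (70, 100)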
@@ -5,14 +5,18 @@ import itertools
 
 
 class LosslessLatentDecoder(nn.Module):
-    def __init__(self, in_channels, latent_depth, dtype=torch.float32):
+    def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
         super(LosslessLatentDecoder, self).__init__()
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.latent_depth = latent_depth
         self.in_channels = in_channels
         self.out_channels = int(in_channels // (latent_depth * latent_depth))
         numpy_kernel = self.build_kernel(in_channels, latent_depth)
-        self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+        numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+        if trainable:
+            self.kernel = nn.Parameter(numpy_kernel)
+        else:
+            self.kernel = numpy_kernel
 
     def build_kernel(self, in_channels, latent_depth):
         # my old code from tensorflow.
@@ -44,14 +48,18 @@ class LosslessLatentDecoder(nn.Module):
 
 
 class LosslessLatentEncoder(nn.Module):
-    def __init__(self, in_channels, latent_depth, dtype=torch.float32):
+    def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
         super(LosslessLatentEncoder, self).__init__()
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.latent_depth = latent_depth
         self.in_channels = in_channels
         self.out_channels = int(in_channels * (latent_depth * latent_depth))
         numpy_kernel = self.build_kernel(in_channels, latent_depth)
-        self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+        numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+        if trainable:
+            self.kernel = nn.Parameter(numpy_kernel)
+        else:
+            self.kernel = numpy_kernel
 
     def build_kernel(self, in_channels, latent_depth):
@@ -82,13 +90,13 @@ class LosslessLatentEncoder(nn.Module):
 
 
 class LosslessLatentVAE(nn.Module):
-    def __init__(self, in_channels, latent_depth, dtype=torch.float32):
+    def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
         super(LosslessLatentVAE, self).__init__()
         self.latent_depth = latent_depth
         self.in_channels = in_channels
-        self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
+        self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
         encoder_out_channels = self.encoder.out_channels
-        self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
+        self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)
 
     def forward(self, x):
         latent = self.encoder(x)
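These encoder/decoder modules implement a lossless space-to-depth rearrangement: the encoder multiplies channels by latent_depth squared while shrinking spatial size, and the decoder inverts it. With a non-trainable kernel, the round trip should behave like torch's built-in pixel shuffle ops; a sketch of that sanity check (treating the channel ordering as equivalent, which is an assumption):

import torch
import torch.nn.functional as F

x = torch.rand(1, 4, 64, 64)
depth = 2

# space-to-depth: (1, 4, 64, 64) -> (1, 16, 32, 32)
encoded = F.pixel_unshuffle(x, downscale_factor=depth)
# depth-to-space inverts it exactly
decoded = F.pixel_shuffle(encoded, upscale_factor=depth)
print(torch.equal(x, decoded))  # True

Making the kernel an nn.Parameter when trainable=True means it shows up in parameters() and state_dict; the non-trainable branch keeps a plain tensor, as before.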
@@ -78,7 +78,7 @@ def flush():
 
 
 UNET_IN_CHANNELS = 4  # Stable Diffusion's in_channels is fixed at 4. Same for XL.
-VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
+# VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
 
 # if is type checking
 if typing.TYPE_CHECKING:
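Computing the factor from the loaded VAE instead of hardcoding 8 keeps this correct for VAEs with a different number of blocks. For the stock Stable Diffusion VAE the arithmetic still yields 8; the block_out_channels values below are the standard SD VAE config:

# standard Stable Diffusion VAE config
block_out_channels = [128, 256, 512, 512]

# each block after the first halves the spatial resolution once
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8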
@@ -471,6 +471,7 @@ class StableDiffusion:
         batch_size=1,
         noise_offset=0.0,
     ):
+        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
         if height is None and pixel_height is None:
             raise ValueError("height or pixel_height must be specified")
         if width is None and pixel_width is None:
@@ -493,6 +494,7 @@ class StableDiffusion:
         return noise
 
     def get_time_ids_from_latents(self, latents: torch.Tensor):
+        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
         if self.is_xl:
             bs, ch, h, w = list(latents.shape)
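get_time_ids_from_latents now derives the scale factor the same way before converting latent dimensions back to pixel dimensions. A sketch of the arithmetic, under the assumption that the method builds SDXL size conditioning, whose time ids are (original size, crop coordinates, target size):

import torch

# latent of a 1024x1024 image through an 8x VAE
latents = torch.rand(1, 4, 128, 128)
VAE_SCALE_FACTOR = 8

bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
add_time_ids = torch.tensor([[height, width, 0, 0, height, width]])
print(add_time_ids)  # tensor([[1024, 1024, 0, 0, 1024, 1024]])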