Bug fixes. Added some functionality to help with private extensions

This commit is contained in:
Jaret Burkett
2023-10-05 07:09:34 -06:00
parent 579650eaf8
commit f73402473b
8 changed files with 99 additions and 20 deletions

View File

@@ -0,0 +1,20 @@
from collections import OrderedDict
import gc
import torch
from jobs.process import BaseExtensionProcess
def flush():
torch.cuda.empty_cache()
gc.collect()
class DatasetTools(BaseExtensionProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
def run(self):
super().run()
raise NotImplementedError("This extension is not yet implemented")

View File

@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# This is for generic training (LoRA, Dreambooth, FineTuning)
class DatasetToolsExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "dataset_tools"
# name is the name of the extension for printing
name = "Dataset Tools"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .DatasetTools import DatasetTools
return DatasetTools
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
DatasetToolsExtension,
]

View File

@@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.trigger_word is not None:
# just so auto1111 will pick it up
o_dict['ss_tag_frequency'] = {
'actfig': {
'actfig': 1
[self.trigger_word ]: {
[self.trigger_word ]: 1
}
}
@@ -827,14 +827,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
else: # no network, embedding or adapter
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
params = self.get_params()
if not params:
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
# we may be using it for prompt injections
if self.adapter_config is not None:
self.setup_adapter()

View File

@@ -478,6 +478,7 @@ def get_dataloader_from_datasets(
datasets = []
has_buckets = False
is_caching_latents = False
dataset_config_list = []
# preprocess them all
@@ -497,6 +498,8 @@ def get_dataloader_from_datasets(
datasets.append(dataset)
if config.buckets:
has_buckets = True
if config.cache_latents or config.cache_latents_to_disk:
is_caching_latents = True
else:
raise ValueError(f"invalid dataset type: {config.type}")
@@ -512,6 +515,9 @@ def get_dataloader_from_datasets(
)
return batch
# check if is caching latents
if has_buckets:
# make sure they all have buckets
for dataset in datasets:

View File

@@ -84,8 +84,22 @@ class DataLoaderBatchDTO:
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
# if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them
if any([x.control_tensor is not None for x in self.file_items]):
# find one to use as a base
base_control_tensor = None
for x in self.file_items:
if x.control_tensor is not None:
base_control_tensor = x.control_tensor
break
control_tensors = []
for x in self.file_items:
if x.control_tensor is None:
control_tensors.append(torch.zeros_like(base_control_tensor))
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
except Exception as e:
print(e)
raise e

View File

@@ -317,6 +317,7 @@ class ImageProcessingDTOMixin:
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
# todo look into this. This still happens sometimes
print('size mismatch')
img = img.crop((
self.crop_x,

View File

@@ -5,14 +5,18 @@ import itertools
class LosslessLatentDecoder(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
super(LosslessLatentDecoder, self).__init__()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels // (latent_depth * latent_depth))
numpy_kernel = self.build_kernel(in_channels, latent_depth)
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
if trainable:
self.kernel = nn.Parameter(numpy_kernel)
else:
self.kernel = numpy_kernel
def build_kernel(self, in_channels, latent_depth):
# my old code from tensorflow.
@@ -44,14 +48,18 @@ class LosslessLatentDecoder(nn.Module):
class LosslessLatentEncoder(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
super(LosslessLatentEncoder, self).__init__()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels * (latent_depth * latent_depth))
numpy_kernel = self.build_kernel(in_channels, latent_depth)
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
if trainable:
self.kernel = nn.Parameter(numpy_kernel)
else:
self.kernel = numpy_kernel
def build_kernel(self, in_channels, latent_depth):
@@ -82,13 +90,13 @@ class LosslessLatentEncoder(nn.Module):
class LosslessLatentVAE(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
super(LosslessLatentVAE, self).__init__()
self.latent_depth = latent_depth
self.in_channels = in_channels
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
encoder_out_channels = self.encoder.out_channels
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)
def forward(self, x):
latent = self.latent_encoder(x)

View File

@@ -78,7 +78,7 @@ def flush():
UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
# if is type checking
if typing.TYPE_CHECKING:
@@ -471,6 +471,7 @@ class StableDiffusion:
batch_size=1,
noise_offset=0.0,
):
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
if height is None and pixel_height is None:
raise ValueError("height or pixel_height must be specified")
if width is None and pixel_width is None:
@@ -493,6 +494,7 @@ class StableDiffusion:
return noise
def get_time_ids_from_latents(self, latents: torch.Tensor):
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
if self.is_xl:
bs, ch, h, w = list(latents.shape)