From f73402473be4af5436968191a461295c719631f6 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 5 Oct 2023 07:09:34 -0600 Subject: [PATCH] Bug fixes. Added some functionality to help with private extensions --- .../dataset_tools/DatasetTools.py | 20 +++++++++++++++ extensions_built_in/dataset_tools/__init__.py | 25 +++++++++++++++++++ jobs/process/BaseSDTrainProcess.py | 23 +++++++++-------- toolkit/data_loader.py | 6 +++++ toolkit/data_transfer_object/data_loader.py | 18 +++++++++++-- toolkit/dataloader_mixins.py | 1 + toolkit/llvae.py | 22 ++++++++++------ toolkit/stable_diffusion_model.py | 4 ++- 8 files changed, 99 insertions(+), 20 deletions(-) create mode 100644 extensions_built_in/dataset_tools/DatasetTools.py create mode 100644 extensions_built_in/dataset_tools/__init__.py diff --git a/extensions_built_in/dataset_tools/DatasetTools.py b/extensions_built_in/dataset_tools/DatasetTools.py new file mode 100644 index 00000000..d969b77a --- /dev/null +++ b/extensions_built_in/dataset_tools/DatasetTools.py @@ -0,0 +1,20 @@ +from collections import OrderedDict +import gc +import torch +from jobs.process import BaseExtensionProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class DatasetTools(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + + def run(self): + super().run() + + raise NotImplementedError("This extension is not yet implemented") diff --git a/extensions_built_in/dataset_tools/__init__.py b/extensions_built_in/dataset_tools/__init__.py new file mode 100644 index 00000000..d466b2d8 --- /dev/null +++ b/extensions_built_in/dataset_tools/__init__.py @@ -0,0 +1,25 @@ +# Dataset Tools extension. The process it loads is a stub whose run() raises NotImplementedError.
+from toolkit.extension import Extension + + +# Extension entry for Dataset Tools; get_process() returns the DatasetTools process class +class DatasetToolsExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "dataset_tools" + + # name is the name of the extension for printing + name = "Dataset Tools" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .DatasetTools import DatasetTools + return DatasetTools + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + DatasetToolsExtension, +] diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index e0e7a4d2..50d415af 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.trigger_word is not None: # just so auto1111 will pick it up o_dict['ss_tag_frequency'] = { - 'actfig': { - 'actfig': 1 + self.trigger_word: { + self.trigger_word: 1 } } @@ -827,14 +827,17 @@ class BaseSDTrainProcess(BaseTrainProcess): else: # no network, embedding or adapter # set the device state preset before getting params self.sd.set_device_state(self.train_device_state_preset) - # will only return savable weights and ones with grad - params = self.sd.prepare_optimizer_params( - unet=self.train_config.train_unet, - text_encoder=self.train_config.train_text_encoder, - text_encoder_lr=self.train_config.lr, - unet_lr=self.train_config.lr, - default_lr=self.train_config.lr - ) + + params = self.get_params() + if not params: + # will only return savable weights and ones with grad + params = self.sd.prepare_optimizer_params( + unet=self.train_config.train_unet, + text_encoder=self.train_config.train_text_encoder, + text_encoder_lr=self.train_config.lr, + unet_lr=self.train_config.lr,
+ default_lr=self.train_config.lr + ) # we may be using it for prompt injections if self.adapter_config is not None: self.setup_adapter() diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 451ff4d2..73655b58 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -478,6 +478,7 @@ def get_dataloader_from_datasets( datasets = [] has_buckets = False + is_caching_latents = False dataset_config_list = [] # preprocess them all @@ -497,6 +498,8 @@ def get_dataloader_from_datasets( datasets.append(dataset) if config.buckets: has_buckets = True + if config.cache_latents or config.cache_latents_to_disk: + is_caching_latents = True else: raise ValueError(f"invalid dataset type: {config.type}") @@ -512,6 +515,9 @@ def get_dataloader_from_datasets( ) return batch + # TODO: act on is_caching_latents here; it is set above but nothing in this change reads it yet + + if has_buckets: # make sure they all have buckets for dataset in datasets: diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 5550dec3..684383f4 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -84,8 +84,22 @@ class DataLoaderBatchDTO: if is_latents_cached: self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) self.control_tensor: Union[torch.Tensor, None] = None - if self.file_items[0].control_tensor is not None: - self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items]) + # file_items[0] alone is no longer checked: any item in the batch may carry a control tensor + # if any have a control tensor, we concatenate them, zero-filling the items that have none + if any([x.control_tensor is not None for x in self.file_items]): + # find one to use as the template for the zero-filled placeholders + base_control_tensor = None + for x in self.file_items: + if x.control_tensor is not None: + base_control_tensor = x.control_tensor + break + control_tensors = [] + for x in self.file_items: + if x.control_tensor is None: + control_tensors.append(torch.zeros_like(base_control_tensor)) + else:
+ control_tensors.append(x.control_tensor) + self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) except Exception as e: print(e) raise e diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 8e13aceb..285c96c8 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -317,6 +317,7 @@ class ImageProcessingDTOMixin: img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height: + # TODO: find out why the resized image can still be smaller than the crop box; this still happens sometimes print('size mismatch') img = img.crop(( self.crop_x, diff --git a/toolkit/llvae.py b/toolkit/llvae.py index e1698ede..9d559bfe 100644 --- a/toolkit/llvae.py +++ b/toolkit/llvae.py @@ -5,14 +5,18 @@ import itertools class LosslessLatentDecoder(nn.Module): - def __init__(self, in_channels, latent_depth, dtype=torch.float32): + def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False): super(LosslessLatentDecoder, self).__init__() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.latent_depth = latent_depth self.in_channels = in_channels self.out_channels = int(in_channels // (latent_depth * latent_depth)) numpy_kernel = self.build_kernel(in_channels, latent_depth) - self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + if trainable: + self.kernel = nn.Parameter(numpy_kernel) + else: + self.kernel = numpy_kernel def build_kernel(self, in_channels, latent_depth): # my old code from tensorflow.
@@ -44,14 +48,18 @@ class LosslessLatentDecoder(nn.Module): class LosslessLatentEncoder(nn.Module): - def __init__(self, in_channels, latent_depth, dtype=torch.float32): + def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False): super(LosslessLatentEncoder, self).__init__() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.latent_depth = latent_depth self.in_channels = in_channels self.out_channels = int(in_channels * (latent_depth * latent_depth)) numpy_kernel = self.build_kernel(in_channels, latent_depth) - self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + if trainable: + self.kernel = nn.Parameter(numpy_kernel) + else: + self.kernel = numpy_kernel def build_kernel(self, in_channels, latent_depth): @@ -82,13 +90,13 @@ class LosslessLatentEncoder(nn.Module): class LosslessLatentVAE(nn.Module): - def __init__(self, in_channels, latent_depth, dtype=torch.float32): + def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False): super(LosslessLatentVAE, self).__init__() self.latent_depth = latent_depth self.in_channels = in_channels - self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype) + self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable) encoder_out_channels = self.encoder.out_channels - self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype) + self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable) def forward(self, x): latent = self.latent_encoder(x) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 98a67e7e..f1959beb 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -78,7 +78,7 @@ def flush(): UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 
で固定。XLも同じ。 -VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 +# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 # if is type checking if typing.TYPE_CHECKING: @@ -471,6 +471,7 @@ class StableDiffusion: batch_size=1, noise_offset=0.0, ): + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) if height is None and pixel_height is None: raise ValueError("height or pixel_height must be specified") if width is None and pixel_width is None: @@ -493,6 +494,7 @@ class StableDiffusion: return noise def get_time_ids_from_latents(self, latents: torch.Tensor): + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) if self.is_xl: bs, ch, h, w = list(latents.shape)