Massive speed increase. Added latent caching both to disk and to memory

Jaret Burkett
2023-09-10 08:54:49 -06:00
parent 41a3f63b72
commit 34bfeba229
10 changed files with 455 additions and 109 deletions


@@ -1,11 +1,10 @@
 import json
 import os
 import random
-from typing import List
+from typing import List, TYPE_CHECKING
 import cv2
 import numpy as np
 import torch
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torchvision import transforms
@@ -13,11 +12,13 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
 from tqdm import tqdm
 import albumentations as A
 from toolkit import image_utils
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
-from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+
 
 class ImageDataset(Dataset, CaptionMixin):
     def __init__(self, config):
@@ -288,9 +289,14 @@ class PairedImageDataset(Dataset):
         return img, prompt, (self.neg_weight, self.pos_weight)
 
 
-class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
-    def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
+class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
+    def __init__(
+            self,
+            dataset_config: 'DatasetConfig',
+            batch_size=1,
+            sd: 'StableDiffusion' = None,
+    ):
         super().__init__()
         self.dataset_config = dataset_config
         folder_path = dataset_config.folder_path
@@ -298,6 +304,15 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
         if self.dataset_path is None:
             self.dataset_path = folder_path
 
+        self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
+        self.is_caching_latents_to_memory = dataset_config.cache_latents
+        self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+        self.sd = sd
+        if self.sd is None and self.is_caching_latents:
+            raise ValueError(f"sd is required for caching latents")
+
         self.caption_type = dataset_config.caption_ext
         self.default_caption = dataset_config.default_caption
         self.random_scale = dataset_config.random_scale
@@ -344,19 +359,21 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
# print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
self.setup_epoch()
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
self.setup_epoch()
def setup_epoch(self):
# TODO: set this up to redo cropping and everything else
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
def __len__(self):
if self.dataset_config.buckets:
@@ -381,7 +398,11 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
         return self._get_single_item(item)
 
 
-def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
+def get_dataloader_from_datasets(
+        dataset_options,
+        batch_size=1,
+        sd: 'StableDiffusion' = None,
+) -> DataLoader:
     if dataset_options is None or len(dataset_options) == 0:
         return None
@@ -402,7 +423,7 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
     for config in dataset_config_list:
         if config.type == 'image':
-            dataset = AiToolkitDataset(config, batch_size=batch_size)
+            dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
             datasets.append(dataset)
             if config.buckets:
                 has_buckets = True
@@ -432,14 +453,14 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
             drop_last=False,
             shuffle=True,
             collate_fn=dto_collation,  # Use the custom collate function
-            num_workers=2
+            num_workers=1
         )
     else:
         data_loader = DataLoader(
             concatenated_dataset,
             batch_size=batch_size,
             shuffle=True,
-            num_workers=2,
+            num_workers=1,
             collate_fn=dto_collation
         )
     return data_loader