mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 16:59:22 +00:00
Massive speed increase. Added latent caching both to disk and to memory
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import List
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torchvision import transforms
|
||||
@@ -13,11 +12,13 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
|
||||
from tqdm import tqdm
|
||||
import albumentations as A
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
class ImageDataset(Dataset, CaptionMixin):
|
||||
def __init__(self, config):
|
||||
@@ -288,9 +289,14 @@ class PairedImageDataset(Dataset):
|
||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||
|
||||
|
||||
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
|
||||
def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_config: 'DatasetConfig',
|
||||
batch_size=1,
|
||||
sd: 'StableDiffusion' = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_config = dataset_config
|
||||
folder_path = dataset_config.folder_path
|
||||
@@ -298,6 +304,15 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
if self.dataset_path is None:
|
||||
self.dataset_path = folder_path
|
||||
|
||||
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
|
||||
self.is_caching_latents_to_memory = dataset_config.cache_latents
|
||||
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
|
||||
|
||||
self.sd = sd
|
||||
|
||||
if self.sd is None and self.is_caching_latents:
|
||||
raise ValueError(f"sd is required for caching latents")
|
||||
|
||||
self.caption_type = dataset_config.caption_ext
|
||||
self.default_caption = dataset_config.default_caption
|
||||
self.random_scale = dataset_config.random_scale
|
||||
@@ -344,19 +359,21 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
# print(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
])
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
def setup_epoch(self):
|
||||
# TODO: set this up to redo cropping and everything else
|
||||
# do not call for now
|
||||
if self.dataset_config.buckets:
|
||||
# setup buckets
|
||||
self.setup_buckets()
|
||||
if self.is_caching_latents:
|
||||
self.cache_latents_all_latents()
|
||||
|
||||
def __len__(self):
|
||||
if self.dataset_config.buckets:
|
||||
@@ -381,7 +398,11 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
return self._get_single_item(item)
|
||||
|
||||
|
||||
def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
|
||||
def get_dataloader_from_datasets(
|
||||
dataset_options,
|
||||
batch_size=1,
|
||||
sd: 'StableDiffusion' = None,
|
||||
) -> DataLoader:
|
||||
if dataset_options is None or len(dataset_options) == 0:
|
||||
return None
|
||||
|
||||
@@ -402,7 +423,7 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
|
||||
for config in dataset_config_list:
|
||||
|
||||
if config.type == 'image':
|
||||
dataset = AiToolkitDataset(config, batch_size=batch_size)
|
||||
dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
|
||||
datasets.append(dataset)
|
||||
if config.buckets:
|
||||
has_buckets = True
|
||||
@@ -432,14 +453,14 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
collate_fn=dto_collation, # Use the custom collate function
|
||||
num_workers=2
|
||||
num_workers=1
|
||||
)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=2,
|
||||
num_workers=1,
|
||||
collate_fn=dto_collation
|
||||
)
|
||||
return data_loader
|
||||
|
||||
Reference in New Issue
Block a user