Massive speed increase. Added latent caching both to disk and to memory

This commit is contained in:
Jaret Burkett
2023-09-10 08:54:49 -06:00
parent 41a3f63b72
commit 34bfeba229
10 changed files with 455 additions and 109 deletions

View File

@@ -6,7 +6,7 @@ from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -21,8 +21,9 @@ def print_once(msg):
printed_messages.append(msg)
class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
def __init__(self, **kwargs):
super().__init__()
self.path = kwargs.get('path', None)
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
# process width and height
@@ -53,12 +54,22 @@ class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
def cleanup(self):
    """Release the per-item data held by this file item.

    Drops the reference to the loaded image tensor and then delegates to
    ``cleanup_latent()`` for the latent side.
    """
    # drop the reference so the image tensor can be garbage collected
    self.tensor = None
    # NOTE(review): cleanup_latent is not defined in this view; presumably it
    # comes from LatentCachingFileItemDTOMixin - confirm what it releases.
    self.cleanup_latent()
class DataLoaderBatchDTO:
def __init__(self, **kwargs):
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
is_latents_cached = self.file_items[0].is_latent_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
# if we have encoded latents, we concatenate them
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]
@@ -82,3 +93,4 @@ class DataLoaderBatchDTO:
self.tensor = None
for file_item in self.file_items:
file_item.cleanup()
del self.tensor