Allow loading auxillery images from dataloader

This commit is contained in:
Jaret Burkett
2023-09-30 07:28:23 -06:00
parent 8d9450ad7c
commit 085787b799
3 changed files with 89 additions and 8 deletions

View File

@@ -6,7 +6,8 @@ from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -21,9 +22,15 @@ def print_once(msg):
printed_messages.append(msg)
class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
def __init__(self, **kwargs):
super().__init__()
class FileItemDTO(
LatentCachingFileItemDTOMixin,
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
ArgBreakMixin,
):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.path = kwargs.get('path', None)
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
# process width and height
@@ -58,6 +65,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
def cleanup(self):
self.tensor = None
self.cleanup_latent()
self.cleanup_control()
class DataLoaderBatchDTO:
@@ -73,6 +81,9 @@ class DataLoaderBatchDTO:
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]
@@ -95,5 +106,6 @@ class DataLoaderBatchDTO:
def cleanup(self):
del self.latents
del self.tensor
del self.control_tensor
for file_item in self.file_items:
file_item.cleanup()