mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Allow loading auxillery images from dataloader
This commit is contained in:
@@ -6,7 +6,8 @@ from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -21,9 +22,15 @@ def print_once(msg):
|
||||
printed_messages.append(msg)
|
||||
|
||||
|
||||
class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
class FileItemDTO(
|
||||
LatentCachingFileItemDTOMixin,
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
ArgBreakMixin,
|
||||
):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.path = kwargs.get('path', None)
|
||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
# process width and height
|
||||
@@ -58,6 +65,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
|
||||
def cleanup(self):
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_control()
|
||||
|
||||
|
||||
class DataLoaderBatchDTO:
|
||||
@@ -73,6 +81,9 @@ class DataLoaderBatchDTO:
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
if self.file_items[0].control_tensor is not None:
|
||||
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
|
||||
|
||||
def get_is_reg_list(self):
|
||||
return [x.is_reg for x in self.file_items]
|
||||
@@ -95,5 +106,6 @@ class DataLoaderBatchDTO:
|
||||
def cleanup(self):
|
||||
del self.latents
|
||||
del self.tensor
|
||||
del self.control_tensor
|
||||
for file_item in self.file_items:
|
||||
file_item.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user