Fixed a big issue with the bucketing dataloader and added random cropping to a point of interest

This commit is contained in:
Jaret Burkett
2023-10-02 18:31:08 -06:00
parent 320e109c5f
commit 579650eaf8
6 changed files with 264 additions and 72 deletions

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -27,6 +27,7 @@ class FileItemDTO(
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
def __init__(self, *args, **kwargs):
@@ -70,20 +71,25 @@ class FileItemDTO(
class DataLoaderBatchDTO:
def __init__(self, **kwargs):
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
is_latents_cached = self.file_items[0].is_latent_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
# if we have encoded latents, we concatenate them
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
try:
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
is_latents_cached = self.file_items[0].is_latent_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
# if we have encoded latents, we concatenate them
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
except Exception as e:
print(e)
raise e
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]