Allow augmentations and targeting different loss types fron the config file

This commit is contained in:
Jaret Burkett
2023-10-18 03:04:57 -06:00
parent da6302ada8
commit 07bf7bd7de
6 changed files with 216 additions and 50 deletions

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -28,6 +28,7 @@ class FileItemDTO(
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
@@ -80,6 +81,8 @@ class DataLoaderBatchDTO:
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -119,11 +122,26 @@ class DataLoaderBatchDTO:
else:
mask_tensors.append(x.mask_tensor)
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
# add unaugmented tensors for ones with augments
if any([x.unaugmented_tensor is not None for x in self.file_items]):
# find one to use as a base
base_unaugmented_tensor = None
for x in self.file_items:
if x.unaugmented_tensor is not None:
base_unaugmented_tensor = x.unaugmented_tensor
break
unaugmented_tensor = []
for x in self.file_items:
if x.unaugmented_tensor is None:
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
else:
unaugmented_tensor.append(x.unaugmented_tensor)
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
except Exception as e:
print(e)
raise e
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]