Added the ability to add masks to the dataloader and SD trainer to adjust the weight of the image

This commit is contained in:
Jaret Burkett
2023-10-09 11:21:00 -06:00
parent 1d3de678aa
commit bb1d3793e3
4 changed files with 127 additions and 14 deletions

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -27,6 +27,7 @@ class FileItemDTO(
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
MaskFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
@@ -67,6 +68,7 @@ class FileItemDTO(
self.tensor = None
self.cleanup_latent()
self.cleanup_control()
self.cleanup_mask()
class DataLoaderBatchDTO:
@@ -76,6 +78,8 @@ class DataLoaderBatchDTO:
is_latents_cached = self.file_items[0].is_latent_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -100,6 +104,21 @@ class DataLoaderBatchDTO:
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
if any([x.mask_tensor is not None for x in self.file_items]):
# find one to use as a base
base_mask_tensor = None
for x in self.file_items:
if x.mask_tensor is not None:
base_mask_tensor = x.mask_tensor
break
mask_tensors = []
for x in self.file_items:
if x.mask_tensor is None:
mask_tensors.append(torch.zeros_like(base_mask_tensor))
else:
mask_tensors.append(x.mask_tensor)
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
except Exception as e:
print(e)
raise e