Switched ip adapter dataloader to clip_image paths so the control paths can be used for training assistant adapters while training ip adapters

This commit is contained in:
Jaret Burkett
2023-12-20 10:32:24 -07:00
parent dfb64b5957
commit 0f597f453e
5 changed files with 94 additions and 12 deletions

View File

@@ -8,7 +8,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -28,6 +28,7 @@ class FileItemDTO(
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
ClipImageFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
UnconditionalFileItemDTOMixin,
@@ -71,6 +72,7 @@ class FileItemDTO(
self.tensor = None
self.cleanup_latent()
self.cleanup_control()
self.cleanup_clip_image()
self.cleanup_mask()
self.cleanup_unconditional()
@@ -83,6 +85,7 @@ class DataLoaderBatchDTO:
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
@@ -113,6 +116,21 @@ class DataLoaderBatchDTO:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
if any([x.clip_image_tensor is not None for x in self.file_items]):
# find one to use as a base
base_clip_image_tensor = None
for x in self.file_items:
if x.clip_image_tensor is not None:
base_clip_image_tensor = x.clip_image_tensor
break
clip_image_tensors = []
for x in self.file_items:
if x.clip_image_tensor is None:
clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
else:
clip_image_tensors.append(x.clip_image_tensor)
self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])
if any([x.mask_tensor is not None for x in self.file_items]):
# find one to use as a base
base_mask_tensor = None