mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Switched ip adapter dataloader to clip_image paths so the control paths can be used for training assistant adapters while training ip adapters
This commit is contained in:
@@ -8,7 +8,7 @@ from PIL.ImageOps import exif_transpose
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||
UnconditionalFileItemDTOMixin
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -28,6 +28,7 @@ class FileItemDTO(
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
ClipImageFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
AugmentationFileItemDTOMixin,
|
||||
UnconditionalFileItemDTOMixin,
|
||||
@@ -71,6 +72,7 @@ class FileItemDTO(
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_control()
|
||||
self.cleanup_clip_image()
|
||||
self.cleanup_mask()
|
||||
self.cleanup_unconditional()
|
||||
|
||||
@@ -83,6 +85,7 @@ class DataLoaderBatchDTO:
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.clip_image_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||
@@ -113,6 +116,21 @@ class DataLoaderBatchDTO:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
if any([x.clip_image_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_clip_image_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.clip_image_tensor is not None:
|
||||
base_clip_image_tensor = x.clip_image_tensor
|
||||
break
|
||||
clip_image_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.clip_image_tensor is None:
|
||||
clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
|
||||
else:
|
||||
clip_image_tensors.append(x.clip_image_tensor)
|
||||
self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])
|
||||
|
||||
if any([x.mask_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_mask_tensor = None
|
||||
|
||||
Reference in New Issue
Block a user