mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-04 01:59:48 +00:00
Added ability to add masks to dataloader and sd trainer to adjust weight of image
This commit is contained in:
@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -27,6 +27,7 @@ class FileItemDTO(
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
PoiFileItemDTOMixin,
|
||||
ArgBreakMixin,
|
||||
):
|
||||
@@ -67,6 +68,7 @@ class FileItemDTO(
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_control()
|
||||
self.cleanup_mask()
|
||||
|
||||
|
||||
class DataLoaderBatchDTO:
|
||||
@@ -76,6 +78,8 @@ class DataLoaderBatchDTO:
|
||||
is_latents_cached = self.file_items[0].is_latent_cached
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
@@ -100,6 +104,21 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
if any([x.mask_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_mask_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.mask_tensor is not None:
|
||||
base_mask_tensor = x.mask_tensor
|
||||
break
|
||||
mask_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.mask_tensor is None:
|
||||
mask_tensors.append(torch.zeros_like(base_mask_tensor))
|
||||
else:
|
||||
mask_tensors.append(x.mask_tensor)
|
||||
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user