mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Handle inpainting training for control_lora adapter
This commit is contained in:
@@ -12,7 +12,7 @@ from PIL.ImageOps import exif_transpose
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -34,6 +34,7 @@ class FileItemDTO(
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
InpaintControlFileItemDTOMixin,
|
||||
ClipImageFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
AugmentationFileItemDTOMixin,
|
||||
@@ -108,6 +109,7 @@ class FileItemDTO(
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_control()
|
||||
self.cleanup_inpaint()
|
||||
self.cleanup_clip_image()
|
||||
self.cleanup_mask()
|
||||
self.cleanup_unconditional()
|
||||
@@ -154,6 +156,22 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
self.inpaint_tensor: Union[torch.Tensor, None] = None
|
||||
if any([x.inpaint_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_inpaint_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.inpaint_tensor is not None:
|
||||
base_inpaint_tensor = x.inpaint_tensor
|
||||
break
|
||||
inpaint_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.inpaint_tensor is None:
|
||||
inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
|
||||
else:
|
||||
inpaint_tensors.append(x.inpaint_tensor)
|
||||
self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])
|
||||
|
||||
self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user