Handle inpainting training for control_lora adapter

This commit is contained in:
Jaret Burkett
2025-03-24 13:17:47 -06:00
parent f10937e6da
commit 45be82d5d6
9 changed files with 257 additions and 23 deletions

View File

@@ -12,7 +12,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
if TYPE_CHECKING:
@@ -34,6 +34,7 @@ class FileItemDTO(
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
InpaintControlFileItemDTOMixin,
ClipImageFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
@@ -108,6 +109,7 @@ class FileItemDTO(
self.tensor = None
self.cleanup_latent()
self.cleanup_control()
self.cleanup_inpaint()
self.cleanup_clip_image()
self.cleanup_mask()
self.cleanup_unconditional()
@@ -154,6 +156,22 @@ class DataLoaderBatchDTO:
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
self.inpaint_tensor: Union[torch.Tensor, None] = None
if any([x.inpaint_tensor is not None for x in self.file_items]):
# find one to use as a base
base_inpaint_tensor = None
for x in self.file_items:
if x.inpaint_tensor is not None:
base_inpaint_tensor = x.inpaint_tensor
break
inpaint_tensors = []
for x in self.file_items:
if x.inpaint_tensor is None:
inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
else:
inpaint_tensors.append(x.inpaint_tensor)
self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])
self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]