mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added a base for using guidance during training. Still not working correctly.
This commit is contained in:
@@ -351,6 +351,7 @@ class DatasetConfig:
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
self.mask_path: str = kwargs.get('mask_path',
|
||||
None) # focus mask (black and white. White has higher loss than black)
|
||||
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
|
||||
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi',
|
||||
|
||||
@@ -7,7 +7,8 @@ from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||
UnconditionalFileItemDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -29,6 +30,7 @@ class FileItemDTO(
|
||||
ControlFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
AugmentationFileItemDTOMixin,
|
||||
UnconditionalFileItemDTOMixin,
|
||||
PoiFileItemDTOMixin,
|
||||
ArgBreakMixin,
|
||||
):
|
||||
@@ -70,6 +72,7 @@ class FileItemDTO(
|
||||
self.cleanup_latent()
|
||||
self.cleanup_control()
|
||||
self.cleanup_mask()
|
||||
self.cleanup_unconditional()
|
||||
|
||||
|
||||
class DataLoaderBatchDTO:
|
||||
@@ -82,6 +85,8 @@ class DataLoaderBatchDTO:
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_latents: Union[torch.Tensor, None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
@@ -138,6 +143,22 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
unaugmented_tensor.append(x.unaugmented_tensor)
|
||||
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
|
||||
|
||||
# add unconditional tensors
|
||||
if any([x.unconditional_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_unconditional_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is not None:
|
||||
base_unconditional_tensor = x.unconditional_tensor
|
||||
break
|
||||
unconditional_tensor = []
|
||||
for x in self.file_items:
|
||||
if x.unconditional_tensor is None:
|
||||
unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
|
||||
else:
|
||||
unconditional_tensor.append(x.unconditional_tensor)
|
||||
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
@@ -351,6 +351,8 @@ class ImageProcessingDTOMixin:
|
||||
self.load_control_image()
|
||||
if self.has_mask_image:
|
||||
self.load_mask_image()
|
||||
if self.has_unconditional:
|
||||
self.load_unconditional_image()
|
||||
return
|
||||
try:
|
||||
img = Image.open(self.path)
|
||||
@@ -442,6 +444,8 @@ class ImageProcessingDTOMixin:
|
||||
self.load_control_image()
|
||||
if self.has_mask_image:
|
||||
self.load_mask_image()
|
||||
if self.has_unconditional:
|
||||
self.load_unconditional_image()
|
||||
|
||||
|
||||
class ControlFileItemDTOMixin:
|
||||
@@ -661,6 +665,80 @@ class MaskFileItemDTOMixin:
|
||||
self.mask_tensor = None
|
||||
|
||||
|
||||
class UnconditionalFileItemDTOMixin:
    # Mixin that locates and loads an "unconditional" companion image for a
    # file item: an image with the same basename stored under
    # dataset_config.unconditional_path. Used for guidance during training.

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        # Cooperative multiple inheritance: pass through to the next mixin.
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_unconditional = False
        self.unconditional_path: Union[str, None] = None
        self.unconditional_tensor: Union[torch.Tensor, None] = None
        self.unconditional_latent: Union[torch.Tensor, None] = None
        # Normalize to [-1, 1], matching the transforms used for the main image.
        self.unconditional_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)

        if dataset_config.unconditional_path is not None:
            # Look for a matching file (same basename, any supported extension)
            # in the unconditional image folder.
            img_path = kwargs.get('path', None)
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                candidate = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)
                if os.path.exists(candidate):
                    self.unconditional_path = candidate
                    self.has_unconditional = True
                    break

    def load_unconditional_image(self: 'FileItemDTO'):
        # Load, orient, flip, scale and crop the unconditional image so it
        # matches the main image's geometry exactly, then store it as a
        # normalized tensor in self.unconditional_tensor.
        try:
            img = Image.open(self.unconditional_path)
            img = exif_transpose(img)
        except Exception as e:
            print(f"Error: {e}")
            # fix: report the unconditional path (was mistakenly self.mask_path)
            print(f"Error loading image: {self.unconditional_path}")
            # fix: re-raise instead of falling through to an unbound `img`,
            # which previously produced a confusing NameError
            raise e

        img = img.convert('RGB')
        w, h = img.size
        # Orientation sanity check: the image aspect must agree with the
        # precomputed scale targets from the main image.
        if w > h and self.scale_to_width < self.scale_to_height:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
        elif h > w and self.scale_to_height < self.scale_to_width:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

        # Mirror the augmentation applied to the main image. fix: PIL's
        # Image.transpose returns a NEW image — the original code discarded
        # the result, so the flips were silently no-ops.
        if self.flip_x:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)

        if self.dataset_config.buckets:
            # scale and crop based on file item so the unconditional image
            # aligns pixel-for-pixel with the main image
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            img = img.crop((
                self.crop_x,
                self.crop_y,
                self.crop_x + self.crop_width,
                self.crop_y + self.crop_height
            ))
        else:
            raise Exception("Unconditional images are not supported for non-bucket datasets")

        self.unconditional_tensor = self.unconditional_transforms(img)

    def cleanup_unconditional(self: 'FileItemDTO'):
        # Drop cached tensors so memory can be reclaimed between steps.
        self.unconditional_tensor = None
        self.unconditional_latent = None
|
||||
|
||||
|
||||
class PoiFileItemDTOMixin:
|
||||
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
||||
# items in the poi will always be inside the image when random cropping
|
||||
|
||||
Reference in New Issue
Block a user