Added ability to load clip pairs randomly from folder. Other small bug fixes

This commit is contained in:
Jaret Burkett
2024-10-03 10:03:49 -06:00
parent f05224970f
commit 67e0aca750
4 changed files with 57 additions and 28 deletions

View File

@@ -607,6 +607,8 @@ class DatasetConfig:
# ip adapter / reference dataset
self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc
# get the clip image randomly from the same folder as the image. Useful for folder grouped pairs.
self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False)
self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
self.replacements: List[str] = kwargs.get('replacements', [])

View File

@@ -1,4 +1,5 @@
import base64
import glob
import hashlib
import json
import math
@@ -630,11 +631,12 @@ class ClipImageFileItemDTOMixin:
self.clip_vision_unconditional_paths: Union[List[str], None] = None
self._clip_vision_embeddings_path: Union[str, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.clip_image_path is not None:
if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder:
# copy the clip image processor so the dataloader can do it
sd = kwargs.get('sd', None)
if hasattr(sd.adapter, 'clip_image_processor'):
self.clip_image_processor = sd.adapter.clip_image_processor
if dataset_config.clip_image_path is not None:
# find the control image path
clip_image_path = dataset_config.clip_image_path
# we are using control images
@@ -646,7 +648,11 @@ class ClipImageFileItemDTOMixin:
self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext)
self.has_clip_image = True
break
self.build_clip_imag_augmentation_transform()
if dataset_config.clip_image_from_same_folder:
# assume we have one. We will pull it on load.
self.has_clip_image = True
self.build_clip_imag_augmentation_transform()
def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
@@ -732,6 +738,24 @@ class ClipImageFileItemDTOMixin:
self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
return self._clip_vision_embeddings_path
def get_new_clip_image_path(self: 'FileItemDTO'):
    """Return the path of the clip (reference) image to load for this item.

    When ``dataset_config.clip_image_from_same_folder`` is enabled, a random
    image is selected from the folder containing this item's image (excluding
    the item itself when another candidate exists), cached on
    ``self.clip_image_path``, and returned. Otherwise the preconfigured
    ``self.clip_image_path`` is returned unchanged.
    """
    if not self.dataset_config.clip_image_from_same_folder:
        return self.clip_image_path

    # randomly grab an image path from the same folder as this item's image
    pool_folder = os.path.dirname(self.path)
    # find all images in the folder
    img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
    img_files = []
    for ext in img_ext_list:
        img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
    # exclude the current image so we pair with a *different* file, unless it
    # is the only image in the folder. Guard membership: self.path may not be
    # in the glob results (e.g. an extension outside img_ext_list), and a bare
    # list.remove() would raise ValueError in that case.
    if len(img_files) > 1 and self.path in img_files:
        img_files.remove(self.path)
    if not img_files:
        # no candidates at all — fall back to the item's own image rather
        # than crashing with random.choice([])
        self.clip_image_path = self.path
    else:
        self.clip_image_path = random.choice(img_files)
    # bug fix: previously returned self.path, which made load_clip_image()
    # open the item's own image instead of the randomly selected clip image
    return self.clip_image_path
def load_clip_image(self: 'FileItemDTO'):
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
@@ -744,14 +768,15 @@ class ClipImageFileItemDTOMixin:
self.clip_image_embeds_unconditional = load_file(unconditional_path)
return
clip_image_path = self.get_new_clip_image_path()
try:
img = Image.open(self.clip_image_path).convert('RGB')
img = Image.open(clip_image_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
# make a random noise image
img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution))
print(f"Error: {e}")
print(f"Error loading image: {self.clip_image_path}")
print(f"Error loading image: {clip_image_path}")
img = img.convert('RGB')

View File

@@ -472,8 +472,8 @@ class VisionDirectAdapter(torch.nn.Module):
self.mid_size = self.token_size
# if pixtral, use cross attn dim for more sparse representation
if is_pixtral:
# if pixtral, use cross attn dim for more sparse representation if only doing double transformers
if is_pixtral and self.config.flux_only_double:
if is_flux:
hidden_size = 3072
else: