mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added the ability to load CLIP image pairs randomly from a folder; other small bug fixes.
This commit is contained in:
@@ -607,6 +607,8 @@ class DatasetConfig:
|
||||
|
||||
# ip adapter / reference dataset
|
||||
self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc
|
||||
# get the clip image randomly from the same folder as the image. Useful for folder grouped pairs.
|
||||
self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False)
|
||||
self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
|
||||
self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
|
||||
self.replacements: List[str] = kwargs.get('replacements', [])
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import math
|
||||
@@ -630,11 +631,12 @@ class ClipImageFileItemDTOMixin:
|
||||
self.clip_vision_unconditional_paths: Union[List[str], None] = None
|
||||
self._clip_vision_embeddings_path: Union[str, None] = None
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
if dataset_config.clip_image_path is not None:
|
||||
if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder:
|
||||
# copy the clip image processor so the dataloader can do it
|
||||
sd = kwargs.get('sd', None)
|
||||
if hasattr(sd.adapter, 'clip_image_processor'):
|
||||
self.clip_image_processor = sd.adapter.clip_image_processor
|
||||
if dataset_config.clip_image_path is not None:
|
||||
# find the control image path
|
||||
clip_image_path = dataset_config.clip_image_path
|
||||
# we are using control images
|
||||
@@ -646,7 +648,11 @@ class ClipImageFileItemDTOMixin:
|
||||
self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext)
|
||||
self.has_clip_image = True
|
||||
break
|
||||
|
||||
self.build_clip_imag_augmentation_transform()
|
||||
|
||||
if dataset_config.clip_image_from_same_folder:
|
||||
# assume we have one. We will pull it on load.
|
||||
self.has_clip_image = True
|
||||
self.build_clip_imag_augmentation_transform()
|
||||
|
||||
def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
|
||||
@@ -732,6 +738,24 @@ class ClipImageFileItemDTOMixin:
|
||||
self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
|
||||
|
||||
return self._clip_vision_embeddings_path
|
||||
|
||||
def get_new_clip_image_path(self: 'FileItemDTO'):
    """Return the clip-image path that should be loaded for this item.

    When ``dataset_config.clip_image_from_same_folder`` is enabled, pick a
    random image from the same folder as this item's image (preferring a
    candidate other than the item itself), store it on
    ``self.clip_image_path``, and return it so the caller loads the chosen
    pair image. Otherwise return the pre-resolved ``self.clip_image_path``.
    """
    if self.dataset_config.clip_image_from_same_folder:
        # pool candidates from the folder containing this item's image
        pool_folder = os.path.dirname(self.path)
        # find all images in the folder
        img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
        img_files = []
        for ext in img_ext_list:
            img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
        # exclude the current image so we pick a true pair, but only when
        # another candidate exists; guard membership so a path the globs
        # did not match (e.g. uppercase extension) cannot raise ValueError
        if len(img_files) > 1 and self.path in img_files:
            img_files.remove(self.path)
        if not img_files:
            # no candidates found (e.g. unsupported extension on the only
            # file) — fall back to the item's own image instead of letting
            # random.choice raise IndexError on an empty list
            img_files = [self.path]
        self.clip_image_path = random.choice(img_files)
        # bug fix: return the freshly chosen clip image, not self.path —
        # the caller opens this return value, so returning self.path made
        # the random-pair feature load the wrong image
        return self.clip_image_path
    else:
        return self.clip_image_path
|
||||
|
||||
def load_clip_image(self: 'FileItemDTO'):
|
||||
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
|
||||
@@ -744,14 +768,15 @@ class ClipImageFileItemDTOMixin:
|
||||
self.clip_image_embeds_unconditional = load_file(unconditional_path)
|
||||
|
||||
return
|
||||
clip_image_path = self.get_new_clip_image_path()
|
||||
try:
|
||||
img = Image.open(self.clip_image_path).convert('RGB')
|
||||
img = Image.open(clip_image_path).convert('RGB')
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
# make a random noise image
|
||||
img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution))
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {self.clip_image_path}")
|
||||
print(f"Error loading image: {clip_image_path}")
|
||||
|
||||
img = img.convert('RGB')
|
||||
|
||||
|
||||
@@ -472,8 +472,8 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
|
||||
self.mid_size = self.token_size
|
||||
|
||||
# if pixtral, use cross attn dim for more sparse representation
|
||||
if is_pixtral:
|
||||
# if pixtral, use cross attn dim for more sparse representation if only doing double transformers
|
||||
if is_pixtral and self.config.flux_only_double:
|
||||
if is_flux:
|
||||
hidden_size = 3072
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user