Added ability to load clip pairs randomly from folder. Other small bug fixes

This commit is contained in:
Jaret Burkett
2024-10-03 10:03:49 -06:00
parent f05224970f
commit 67e0aca750
4 changed files with 57 additions and 28 deletions

View File

@@ -34,6 +34,7 @@ class GenerateConfig:
self.compile = kwargs.get('compile', False)
self.ext = kwargs.get('ext', 'png')
self.prompt_file = kwargs.get('prompt_file', False)
self.num_repeats = kwargs.get('num_repeats', 1)
self.prompts_in_file = self.prompts
if self.prompts is None:
raise ValueError("Prompts must be set")
@@ -110,30 +111,31 @@ class GenerateProcess(BaseProcess):
print(f"Generating {len(self.generate_config.prompts)} images")
# build prompt image configs
prompt_image_configs = []
for prompt in self.generate_config.prompts:
width = self.generate_config.width
height = self.generate_config.height
prompt = self.clean_prompt(prompt)
for _ in range(self.generate_config.num_repeats):
for prompt in self.generate_config.prompts:
width = self.generate_config.width
height = self.generate_config.height
# prompt = self.clean_prompt(prompt)
if self.generate_config.size_list is not None:
# randomly select a size
width, height = random.choice(self.generate_config.size_list)
if self.generate_config.size_list is not None:
# randomly select a size
width, height = random.choice(self.generate_config.size_list)
prompt_image_configs.append(GenerateImageConfig(
prompt=prompt,
prompt_2=self.generate_config.prompt_2,
width=width,
height=height,
num_inference_steps=self.generate_config.sample_steps,
guidance_scale=self.generate_config.guidance_scale,
negative_prompt=self.generate_config.neg,
negative_prompt_2=self.generate_config.neg_2,
seed=self.generate_config.seed,
guidance_rescale=self.generate_config.guidance_rescale,
output_ext=self.generate_config.ext,
output_folder=self.output_folder,
add_prompt_file=self.generate_config.prompt_file
))
prompt_image_configs.append(GenerateImageConfig(
prompt=prompt,
prompt_2=self.generate_config.prompt_2,
width=width,
height=height,
num_inference_steps=self.generate_config.sample_steps,
guidance_scale=self.generate_config.guidance_scale,
negative_prompt=self.generate_config.neg,
negative_prompt_2=self.generate_config.neg_2,
seed=self.generate_config.seed,
guidance_rescale=self.generate_config.guidance_rescale,
output_ext=self.generate_config.ext,
output_folder=self.output_folder,
add_prompt_file=self.generate_config.prompt_file
))
# generate images
self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)

View File

@@ -607,6 +607,8 @@ class DatasetConfig:
# ip adapter / reference dataset
self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc
# get the clip image randomly from the same folder as the image. Useful for folder grouped pairs.
self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False)
self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
self.replacements: List[str] = kwargs.get('replacements', [])

View File

@@ -1,4 +1,5 @@
import base64
import glob
import hashlib
import json
import math
@@ -630,11 +631,12 @@ class ClipImageFileItemDTOMixin:
self.clip_vision_unconditional_paths: Union[List[str], None] = None
self._clip_vision_embeddings_path: Union[str, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.clip_image_path is not None:
if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder:
# copy the clip image processor so the dataloader can do it
sd = kwargs.get('sd', None)
if hasattr(sd.adapter, 'clip_image_processor'):
self.clip_image_processor = sd.adapter.clip_image_processor
if dataset_config.clip_image_path is not None:
# find the control image path
clip_image_path = dataset_config.clip_image_path
# we are using control images
@@ -646,7 +648,11 @@ class ClipImageFileItemDTOMixin:
self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext)
self.has_clip_image = True
break
self.build_clip_imag_augmentation_transform()
if dataset_config.clip_image_from_same_folder:
# assume we have one. We will pull it on load.
self.has_clip_image = True
self.build_clip_imag_augmentation_transform()
def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
@@ -732,6 +738,24 @@ class ClipImageFileItemDTOMixin:
self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
return self._clip_vision_embeddings_path
def get_new_clip_image_path(self: 'FileItemDTO'):
if self.dataset_config.clip_image_from_same_folder:
# randomly grab an image path from the same folder
pool_folder = os.path.dirname(self.path)
# find all images in the folder
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
img_files = []
for ext in img_ext_list:
img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
# remove the current image if len is greater than 1
if len(img_files) > 1:
img_files.remove(self.path)
# randomly grab one
self.clip_image_path = random.choice(img_files)
return self.path
else:
return self.clip_image_path
def load_clip_image(self: 'FileItemDTO'):
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
@@ -744,14 +768,15 @@ class ClipImageFileItemDTOMixin:
self.clip_image_embeds_unconditional = load_file(unconditional_path)
return
clip_image_path = self.get_new_clip_image_path()
try:
img = Image.open(self.clip_image_path).convert('RGB')
img = Image.open(clip_image_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
# make a random noise image
img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution))
print(f"Error: {e}")
print(f"Error loading image: {self.clip_image_path}")
print(f"Error loading image: {clip_image_path}")
img = img.convert('RGB')

View File

@@ -472,8 +472,8 @@ class VisionDirectAdapter(torch.nn.Module):
self.mid_size = self.token_size
# if pixtral, use cross attn dim for more sparse representation
if is_pixtral:
# if pixtral, use cross attn dim for more sparse representation if only doing double transformers
if is_pixtral and self.config.flux_only_double:
if is_flux:
hidden_size = 3072
else: