Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added the ability to load clip image pairs randomly from a folder. Other small bug fixes.
@@ -34,6 +34,7 @@ class GenerateConfig:
         self.compile = kwargs.get('compile', False)
         self.ext = kwargs.get('ext', 'png')
         self.prompt_file = kwargs.get('prompt_file', False)
+        self.num_repeats = kwargs.get('num_repeats', 1)
         self.prompts_in_file = self.prompts
         if self.prompts is None:
             raise ValueError("Prompts must be set")
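For orientation, GenerateConfig consumes plain keyword values, so the new option rides along with the existing generation settings. A minimal sketch of how num_repeats might arrive (the dict below is hypothetical; only keys that appear in this diff are assumed to exist):

# Hypothetical generate-job settings as they would reach GenerateConfig's
# kwargs; num_repeats (new in this commit) defaults to 1 when omitted.
generate_kwargs = {
    'prompts': ['a photo of a cat', 'a photo of a dog'],
    'width': 1024,
    'height': 1024,
    'ext': 'png',
    'num_repeats': 4,  # render each prompt four times
}
num_repeats = generate_kwargs.get('num_repeats', 1)
print(num_repeats)  # 4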
@@ -110,30 +111,31 @@ class GenerateProcess(BaseProcess):
         print(f"Generating {len(self.generate_config.prompts)} images")
         # build prompt image configs
         prompt_image_configs = []
-        for prompt in self.generate_config.prompts:
-            width = self.generate_config.width
-            height = self.generate_config.height
-            prompt = self.clean_prompt(prompt)
+        for _ in range(self.generate_config.num_repeats):
+            for prompt in self.generate_config.prompts:
+                width = self.generate_config.width
+                height = self.generate_config.height
+                # prompt = self.clean_prompt(prompt)

-            if self.generate_config.size_list is not None:
-                # randomly select a size
-                width, height = random.choice(self.generate_config.size_list)
+                if self.generate_config.size_list is not None:
+                    # randomly select a size
+                    width, height = random.choice(self.generate_config.size_list)

-            prompt_image_configs.append(GenerateImageConfig(
-                prompt=prompt,
-                prompt_2=self.generate_config.prompt_2,
-                width=width,
-                height=height,
-                num_inference_steps=self.generate_config.sample_steps,
-                guidance_scale=self.generate_config.guidance_scale,
-                negative_prompt=self.generate_config.neg,
-                negative_prompt_2=self.generate_config.neg_2,
-                seed=self.generate_config.seed,
-                guidance_rescale=self.generate_config.guidance_rescale,
-                output_ext=self.generate_config.ext,
-                output_folder=self.output_folder,
-                add_prompt_file=self.generate_config.prompt_file
-            ))
+                prompt_image_configs.append(GenerateImageConfig(
+                    prompt=prompt,
+                    prompt_2=self.generate_config.prompt_2,
+                    width=width,
+                    height=height,
+                    num_inference_steps=self.generate_config.sample_steps,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    negative_prompt=self.generate_config.neg,
+                    negative_prompt_2=self.generate_config.neg_2,
+                    seed=self.generate_config.seed,
+                    guidance_rescale=self.generate_config.guidance_rescale,
+                    output_ext=self.generate_config.ext,
+                    output_folder=self.output_folder,
+                    add_prompt_file=self.generate_config.prompt_file
+                ))
         # generate images
         self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)

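The rewritten loop repeats the whole prompt list num_repeats times and re-rolls the random size on every pass, so repeated renders of one prompt can land at different resolutions. A self-contained sketch of the resulting schedule (names are illustrative, not from the repo):

import random

# Illustrative only: mimics the repeat/size logic in the hunk above.
prompts = ['a cat', 'a dog']
num_repeats = 2
size_list = [(512, 512), (768, 512)]

jobs = []
for _ in range(num_repeats):
    for prompt in prompts:
        width, height = random.choice(size_list)  # re-rolled per image
        jobs.append((prompt, width, height))

print(len(jobs))  # 4 images: each prompt rendered num_repeats times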
@@ -607,6 +607,8 @@ class DatasetConfig:

         # ip adapter / reference dataset
         self.clip_image_path: str = kwargs.get('clip_image_path', None)  # depth maps, etc
+        # get the clip image randomly from the same folder as the image. Useful for folder grouped pairs.
+        self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False)
         self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
         self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
         self.replacements: List[str] = kwargs.get('replacements', [])
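A sketch of how the new flag might reach DatasetConfig, assuming a layout where each subfolder groups images that should serve as clip-image pairs for one another (folder_path is an assumed key, not part of this hunk):

# Hypothetical dataset entry as it would reach DatasetConfig's kwargs.
# Assumed folder layout:
#   dataset/person_a/img_001.jpg
#   dataset/person_a/img_002.jpg
dataset_kwargs = {
    'folder_path': 'dataset/person_a',    # assumed key, not in this hunk
    'clip_image_from_same_folder': True,  # new: pull the clip image from the same folder
}
clip_from_folder = dataset_kwargs.get('clip_image_from_same_folder', False)
print(clip_from_folder)  # True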
@@ -1,4 +1,5 @@
 import base64
+import glob
 import hashlib
 import json
 import math
@@ -630,11 +631,12 @@ class ClipImageFileItemDTOMixin:
         self.clip_vision_unconditional_paths: Union[List[str], None] = None
         self._clip_vision_embeddings_path: Union[str, None] = None
         dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
-        if dataset_config.clip_image_path is not None:
+        if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder:
             # copy the clip image processor so the dataloader can do it
             sd = kwargs.get('sd', None)
             if hasattr(sd.adapter, 'clip_image_processor'):
                 self.clip_image_processor = sd.adapter.clip_image_processor
+        if dataset_config.clip_image_path is not None:
             # find the control image path
             clip_image_path = dataset_config.clip_image_path
             # we are using control images
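The broadened condition copies the adapter's clip image processor for the same-folder mode too, even when no fixed clip_image_path is set. A runnable sketch of the hasattr guard it relies on (the dummy classes are stand-ins, not repo types):

# Illustrative only: not every adapter exposes a clip_image_processor,
# which is why the DTO guards the copy with hasattr.
class DummyAdapter:
    clip_image_processor = object()  # stands in for a real image processor

class DummySD:
    adapter = DummyAdapter()

sd = DummySD()
clip_image_processor = None
if hasattr(sd.adapter, 'clip_image_processor'):
    clip_image_processor = sd.adapter.clip_image_processor
print(clip_image_processor is not None)  # True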
@@ -646,7 +648,11 @@ class ClipImageFileItemDTOMixin:
                     self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext)
                     self.has_clip_image = True
                     break

             self.build_clip_imag_augmentation_transform()
+        if dataset_config.clip_image_from_same_folder:
+            # assume we have one. We will pull it on load.
+            self.has_clip_image = True
+            self.build_clip_imag_augmentation_transform()

     def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
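With clip_image_from_same_folder, the DTO no longer resolves a clip path eagerly at construction; it just assumes a pair exists and picks one at load time. A hypothetical miniature of that eager-versus-lazy split:

# Hypothetical miniature of the DTO behavior in this hunk, not the
# repo's actual class.
class MiniItem:
    def __init__(self, path, clip_image_path=None, from_same_folder=False):
        self.path = path
        self.clip_image_path = clip_image_path  # eager: fixed at construction
        self.from_same_folder = from_same_folder
        # lazy: assume a sibling pair exists; it is re-resolved on each load
        self.has_clip_image = clip_image_path is not None or from_same_folder

item = MiniItem('dataset/person_a/img_001.jpg', from_same_folder=True)
print(item.has_clip_image)  # True, although no clip path is chosen yet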
@@ -732,6 +738,24 @@ class ClipImageFileItemDTOMixin:
             self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')

         return self._clip_vision_embeddings_path

+    def get_new_clip_image_path(self: 'FileItemDTO'):
+        if self.dataset_config.clip_image_from_same_folder:
+            # randomly grab an image path from the same folder
+            pool_folder = os.path.dirname(self.path)
+            # find all images in the folder
+            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+            img_files = []
+            for ext in img_ext_list:
+                img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
+            # remove the current image if len is greater than 1
+            if len(img_files) > 1:
+                img_files.remove(self.path)
+            # randomly grab one and return it
+            self.clip_image_path = random.choice(img_files)
+            return self.clip_image_path
+        else:
+            return self.clip_image_path
+
     def load_clip_image(self: 'FileItemDTO'):
         is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
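To make the new get_new_clip_image_path concrete, here is a self-contained version of the same folder-pool logic (function name and paths are illustrative; path strings must match glob's output exactly for the remove call to work):

import glob
import os
import random

def pick_pair_from_folder(image_path: str) -> str:
    # collect every candidate image in the same folder
    pool_folder = os.path.dirname(image_path)
    img_files = []
    for ext in ['.jpg', '.jpeg', '.png', '.webp']:
        img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
    # avoid pairing an image with itself when siblings exist
    if len(img_files) > 1:
        img_files.remove(image_path)
    return random.choice(img_files)

# e.g. pick_pair_from_folder('dataset/person_a/img_001.jpg') might return
# 'dataset/person_a/img_002.jpg', re-rolled on every call (so every epoch)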
@@ -744,14 +768,15 @@ class ClipImageFileItemDTOMixin:
                 self.clip_image_embeds_unconditional = load_file(unconditional_path)

             return
+        clip_image_path = self.get_new_clip_image_path()
         try:
-            img = Image.open(self.clip_image_path).convert('RGB')
+            img = Image.open(clip_image_path).convert('RGB')
             img = exif_transpose(img)
         except Exception as e:
             # make a random noise image
             img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution))
             print(f"Error: {e}")
-            print(f"Error loading image: {self.clip_image_path}")
+            print(f"Error loading image: {clip_image_path}")

         img = img.convert('RGB')

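The except branch substitutes a placeholder image so one unreadable file degrades a single sample instead of crashing the run; note that Image.new with no fill color actually produces a black image rather than noise, despite the upstream comment. A runnable sketch of the fallback (the resolution value is illustrative):

from PIL import Image

resolution = 512  # illustrative; the repo takes this from dataset_config
try:
    img = Image.open('missing_file.png').convert('RGB')
except Exception as e:
    # fall back to a solid black placeholder at the training resolution
    img = Image.new('RGB', (resolution, resolution))
    print(f"Error: {e}")
print(img.size)  # (512, 512)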
@@ -472,8 +472,8 @@ class VisionDirectAdapter(torch.nn.Module):

         self.mid_size = self.token_size

-        # if pixtral, use cross attn dim for more sparse representation
-        if is_pixtral:
+        # if pixtral, use cross attn dim for more sparse representation if only doing double transformers
+        if is_pixtral and self.config.flux_only_double:
             if is_flux:
                 hidden_size = 3072
             else: