diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py index 1061f65a..e0cb32d8 100644 --- a/jobs/process/GenerateProcess.py +++ b/jobs/process/GenerateProcess.py @@ -34,6 +34,7 @@ class GenerateConfig: self.compile = kwargs.get('compile', False) self.ext = kwargs.get('ext', 'png') self.prompt_file = kwargs.get('prompt_file', False) + self.num_repeats = kwargs.get('num_repeats', 1) self.prompts_in_file = self.prompts if self.prompts is None: raise ValueError("Prompts must be set") @@ -110,30 +111,31 @@ class GenerateProcess(BaseProcess): print(f"Generating {len(self.generate_config.prompts)} images") # build prompt image configs prompt_image_configs = [] - for prompt in self.generate_config.prompts: - width = self.generate_config.width - height = self.generate_config.height - prompt = self.clean_prompt(prompt) + for _ in range(self.generate_config.num_repeats): + for prompt in self.generate_config.prompts: + width = self.generate_config.width + height = self.generate_config.height + # prompt = self.clean_prompt(prompt) - if self.generate_config.size_list is not None: - # randomly select a size - width, height = random.choice(self.generate_config.size_list) + if self.generate_config.size_list is not None: + # randomly select a size + width, height = random.choice(self.generate_config.size_list) - prompt_image_configs.append(GenerateImageConfig( - prompt=prompt, - prompt_2=self.generate_config.prompt_2, - width=width, - height=height, - num_inference_steps=self.generate_config.sample_steps, - guidance_scale=self.generate_config.guidance_scale, - negative_prompt=self.generate_config.neg, - negative_prompt_2=self.generate_config.neg_2, - seed=self.generate_config.seed, - guidance_rescale=self.generate_config.guidance_rescale, - output_ext=self.generate_config.ext, - output_folder=self.output_folder, - add_prompt_file=self.generate_config.prompt_file - )) + prompt_image_configs.append(GenerateImageConfig( + prompt=prompt, + 
prompt_2=self.generate_config.prompt_2, + width=width, + height=height, + num_inference_steps=self.generate_config.sample_steps, + guidance_scale=self.generate_config.guidance_scale, + negative_prompt=self.generate_config.neg, + negative_prompt_2=self.generate_config.neg_2, + seed=self.generate_config.seed, + guidance_rescale=self.generate_config.guidance_rescale, + output_ext=self.generate_config.ext, + output_folder=self.output_folder, + add_prompt_file=self.generate_config.prompt_file + )) # generate images self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ad8baa4a..3060bb5a 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -607,6 +607,8 @@ class DatasetConfig: # ip adapter / reference dataset self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc + # get the clip image randomly from the same folder as the image. Useful for folder grouped pairs. 
+ self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False) self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None) self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False) self.replacements: List[str] = kwargs.get('replacements', []) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index ca2a6815..946c3d27 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1,4 +1,5 @@ import base64 +import glob import hashlib import json import math @@ -630,11 +631,12 @@ class ClipImageFileItemDTOMixin: self.clip_vision_unconditional_paths: Union[List[str], None] = None self._clip_vision_embeddings_path: Union[str, None] = None dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) - if dataset_config.clip_image_path is not None: + if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder: # copy the clip image processor so the dataloader can do it sd = kwargs.get('sd', None) if hasattr(sd.adapter, 'clip_image_processor'): self.clip_image_processor = sd.adapter.clip_image_processor + if dataset_config.clip_image_path is not None: # find the control image path clip_image_path = dataset_config.clip_image_path # we are using control images @@ -646,7 +648,11 @@ class ClipImageFileItemDTOMixin: self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext) self.has_clip_image = True break - + self.build_clip_imag_augmentation_transform() + + if dataset_config.clip_image_from_same_folder: + # assume we have one. We will pull it on load. 
+ self.has_clip_image = True self.build_clip_imag_augmentation_transform() def build_clip_imag_augmentation_transform(self: 'FileItemDTO'): @@ -732,6 +738,24 @@ class ClipImageFileItemDTOMixin: self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors') return self._clip_vision_embeddings_path + + def get_new_clip_image_path(self: 'FileItemDTO'): + if self.dataset_config.clip_image_from_same_folder: + # randomly grab an image path from the same folder + pool_folder = os.path.dirname(self.path) + # find all images in the folder + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + img_files = [] + for ext in img_ext_list: + img_files += glob.glob(os.path.join(pool_folder, f'*{ext}')) + # remove the current image if len is greater than 1 + if len(img_files) > 1: + img_files.remove(self.path) + # randomly grab one + self.clip_image_path = random.choice(img_files) + return self.clip_image_path + else: + return self.clip_image_path def load_clip_image(self: 'FileItemDTO'): is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible) @@ -744,14 +768,15 @@ class ClipImageFileItemDTOMixin: self.clip_image_embeds_unconditional = load_file(unconditional_path) return + clip_image_path = self.get_new_clip_image_path() try: - img = Image.open(self.clip_image_path).convert('RGB') + img = Image.open(clip_image_path).convert('RGB') img = exif_transpose(img) except Exception as e: # make a random noise image img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution)) print(f"Error: {e}") - print(f"Error loading image: {self.clip_image_path}") + print(f"Error loading image: {clip_image_path}") img = img.convert('RGB') diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py index 83407c9d..d3f396b0 100644 --- a/toolkit/models/vd_adapter.py +++ b/toolkit/models/vd_adapter.py @@ -472,8 +472,8 @@ class VisionDirectAdapter(torch.nn.Module): self.mid_size =
self.token_size - # if pixtral, use cross attn dim for more sparse representation - if is_pixtral: + # if pixtral, use cross attn dim for more sparse representation if only doing double transformers + if is_pixtral and self.config.flux_only_double: if is_flux: hidden_size = 3072 else: