From f3eb1dff4288f1baaeced86d5fe5214ead9c44a8 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 24 Jun 2025 08:51:29 -0600 Subject: [PATCH] Add a config flag to trigger fast image size db builder. Add config flag to set unconditional prompt for guidance loss --- extensions_built_in/sd_trainer/SDTrainer.py | 2 +- toolkit/config_modules.py | 4 ++++ toolkit/data_transfer_object/data_loader.py | 19 ++++++++++--------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 1722f236..d5a1102e 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -147,7 +147,7 @@ class SDTrainer(BaseSDTrainProcess): # cache unconditional embeds (blank prompt) with torch.no_grad(): self.unconditional_embeds = self.sd.encode_prompt( - [''], + [self.train_config.unconditional_prompt], long_prompts=self.do_long_prompts ).to( self.device_torch, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index b53b6613..aeca126e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -471,6 +471,7 @@ class TrainConfig: # contrastive loss self.do_guidance_loss = kwargs.get('do_guidance_loss', False) self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0) + self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '') if isinstance(self.guidance_loss_target, tuple): self.guidance_loss_target = list(self.guidance_loss_target) @@ -837,6 +838,9 @@ class DatasetConfig: self.controls = [self.controls] # remove empty strings self.controls = [control for control in self.controls if control.strip() != ''] + + # if true, will use a fast method to get image sizes. This can result in errors. 
Do not use unless you know what you are doing + self.fast_image_size: bool = kwargs.get('fast_image_size', False) def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index ba235e91..0c7d7562 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -84,15 +84,16 @@ class FileItemDTO( video.release() size_database[file_key] = (width, height, file_signature) else: - # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now. - # process width and height - # try: - # w, h = image_utils.get_image_size(self.path) - # except image_utils.UnknownImageFormat: - # print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ - # f'This process is faster for png, jpeg') - img = exif_transpose(Image.open(self.path)) - w, h = img.size + if self.dataset_config.fast_image_size: + # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default. + try: + w, h = image_utils.get_image_size(self.path) + except image_utils.UnknownImageFormat: + print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ + f'This process is faster for png, jpeg') + else: + img = exif_transpose(Image.open(self.path)) + w, h = img.size size_database[file_key] = (w, h, file_signature) self.width: int = w self.height: int = h