Add a config flag to trigger fast image size db builder. Add config flag to set unconditional prompt for guidance loss

This commit is contained in:
Jaret Burkett
2025-06-24 08:51:29 -06:00
parent ba1274d99e
commit f3eb1dff42
3 changed files with 15 additions and 10 deletions

View File

@@ -471,6 +471,7 @@ class TrainConfig:
# contrastive loss
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)
@@ -837,6 +838,9 @@ class DatasetConfig:
self.controls = [self.controls]
# remove empty strings
self.controls = [control for control in self.controls if control.strip() != '']
# if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing.
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

View File

@@ -84,15 +84,16 @@ class FileItemDTO(
video.release()
size_database[file_key] = (width, height, file_signature)
else:
# original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
# process width and height
# try:
# w, h = image_utils.get_image_size(self.path)
# except image_utils.UnknownImageFormat:
# print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
# f'This process is faster for png, jpeg')
img = exif_transpose(Image.open(self.path))
w, h = img.size
if self.dataset_config.fast_image_size:
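# The original (fast) method is significantly faster, but some images are read sideways (presumably because EXIF orientation is not applied, unlike the exif_transpose path below; TODO confirm). The slow method is used by default.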
try:
w, h = image_utils.get_image_size(self.path)
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
else:
img = exif_transpose(Image.open(self.path))
w, h = img.size
size_database[file_key] = (w, h, file_signature)
self.width: int = w
self.height: int = h