diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index d8c6c6bb..3268aa13 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -748,6 +748,7 @@ class DatasetConfig: self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1 self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes + self.use_short_captions: bool = kwargs.get('use_short_captions', False) # if true, will use 'caption_short' from json self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset # cache latents will store them in memory self.cache_latents: bool = kwargs.get('cache_latents', False) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 395747e6..255e70f2 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -308,6 +308,8 @@ class CaptionProcessingDTOMixin: self.raw_caption = caption_dict[self.path]["caption"] if 'caption_short' in caption_dict[self.path]: self.raw_caption_short = caption_dict[self.path]["caption_short"] + if self.dataset_config.use_short_captions: + self.raw_caption = caption_dict[self.path]["caption_short"] else: # see if prompt file exists path_no_ext = os.path.splitext(self.path)[0] @@ -330,7 +332,8 @@ class CaptionProcessingDTOMixin: prompt = prompt_json['caption'] if 'caption_short' in prompt_json: short_caption = prompt_json['caption_short'] - + if self.dataset_config.use_short_captions: + prompt = short_caption if 'extra_values' in prompt_json: self.extra_values = prompt_json['extra_values']