diff --git a/info.py b/info.py index 81cffcad..9f2f0a97 100644 --- a/info.py +++ b/info.py @@ -3,6 +3,6 @@ from collections import OrderedDict v = OrderedDict() v["name"] = "ai-toolkit" v["repo"] = "https://github.com/ostris/ai-toolkit" -v["version"] = "0.0.4" +v["version"] = "0.1.0" software_meta = v diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index be0e6725..904084e8 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -223,6 +223,7 @@ class DatasetConfig: self.dataset_path: str = kwargs.get('dataset_path', None) self.default_caption: str = kwargs.get('default_caption', None) + self.random_triggers: List[str] = kwargs.get('random_triggers', []) self.caption_ext: str = kwargs.get('caption_ext', None) self.random_scale: bool = kwargs.get('random_scale', False) self.random_crop: bool = kwargs.get('random_crop', False) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 00cb536e..444a3c32 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -218,6 +218,11 @@ class CaptionProcessingDTOMixin: with open(prompt_path, 'r', encoding='utf-8') as f: prompt = f.read() if prompt_path.endswith('.json'): + # replace any line endings with commas for \n \r \r\n + prompt = prompt.replace('\r\n', ' ') + prompt = prompt.replace('\n', ' ') + prompt = prompt.replace('\r', ' ') + prompt = json.loads(prompt) if 'caption' in prompt: prompt = prompt['caption'] @@ -277,6 +282,21 @@ class CaptionProcessingDTOMixin: # join back together caption = ', '.join(token_list) caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present) + + if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0: + # add random triggers + caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption + + if self.dataset_config.shuffle_tokens: + # shuffle again + token_list = caption.split(',') + # trim whitespace + token_list = [x.strip() for x in token_list] + # remove empty strings + token_list = [x for x in token_list if x] + random.shuffle(token_list) + caption = ', '.join(token_list) + return caption