hardened reading prompts from json

This commit is contained in:
Jaret Burkett
2023-10-15 07:20:33 -06:00
parent 7909b50d24
commit b1a22d0b3e
3 changed files with 22 additions and 1 deletions

View File

@@ -3,6 +3,6 @@ from collections import OrderedDict
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.0.4"
v["version"] = "0.1.0"
software_meta = v

View File

@@ -223,6 +223,7 @@ class DatasetConfig:
self.dataset_path: str = kwargs.get('dataset_path', None)
self.default_caption: str = kwargs.get('default_caption', None)
self.random_triggers: List[str] = kwargs.get('random_triggers', [])
self.caption_ext: str = kwargs.get('caption_ext', None)
self.random_scale: bool = kwargs.get('random_scale', False)
self.random_crop: bool = kwargs.get('random_crop', False)

View File

@@ -218,6 +218,11 @@ class CaptionProcessingDTOMixin:
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
if prompt_path.endswith('.json'):
# replace any line endings with commas for \n \r \r\n
prompt = prompt.replace('\r\n', ' ')
prompt = prompt.replace('\n', ' ')
prompt = prompt.replace('\r', ' ')
prompt = json.loads(prompt)
if 'caption' in prompt:
prompt = prompt['caption']
@@ -277,6 +282,21 @@ class CaptionProcessingDTOMixin:
# join back together
caption = ', '.join(token_list)
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
# add random triggers
caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
if self.dataset_config.shuffle_tokens:
# shuffle again
token_list = caption.split(',')
# trim whitespace
token_list = [x.strip() for x in token_list]
# remove empty strings
token_list = [x for x in token_list if x]
random.shuffle(token_list)
caption = ', '.join(token_list)
return caption