hardened reading prompts from json

This commit is contained in:
Jaret Burkett
2023-10-15 07:20:33 -06:00
parent 7909b50d24
commit b1a22d0b3e
3 changed files with 22 additions and 1 deletions

View File

@@ -218,6 +218,11 @@ class CaptionProcessingDTOMixin:
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
if prompt_path.endswith('.json'):
# replace any line endings with commas for \n \r \r\n
prompt = prompt.replace('\r\n', ' ')
prompt = prompt.replace('\n', ' ')
prompt = prompt.replace('\r', ' ')
prompt = json.loads(prompt)
if 'caption' in prompt:
prompt = prompt['caption']
@@ -277,6 +282,21 @@ class CaptionProcessingDTOMixin:
# join back together
caption = ', '.join(token_list)
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
# add random triggers
caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
if self.dataset_config.shuffle_tokens:
# shuffle again
token_list = caption.split(',')
# trim whitespace
token_list = [x.strip() for x in token_list]
# remove empty strings
token_list = [x for x in token_list if x]
random.shuffle(token_list)
caption = ', '.join(token_list)
return caption