mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
hardened reading prompts from json
This commit is contained in:
2
info.py
2
info.py
@@ -3,6 +3,6 @@ from collections import OrderedDict
|
||||
v = OrderedDict()
|
||||
v["name"] = "ai-toolkit"
|
||||
v["repo"] = "https://github.com/ostris/ai-toolkit"
|
||||
v["version"] = "0.0.4"
|
||||
v["version"] = "0.1.0"
|
||||
|
||||
software_meta = v
|
||||
|
||||
@@ -223,6 +223,7 @@ class DatasetConfig:
|
||||
self.dataset_path: str = kwargs.get('dataset_path', None)
|
||||
|
||||
self.default_caption: str = kwargs.get('default_caption', None)
|
||||
self.random_triggers: List[str] = kwargs.get('random_triggers', [])
|
||||
self.caption_ext: str = kwargs.get('caption_ext', None)
|
||||
self.random_scale: bool = kwargs.get('random_scale', False)
|
||||
self.random_crop: bool = kwargs.get('random_crop', False)
|
||||
|
||||
@@ -218,6 +218,11 @@ class CaptionProcessingDTOMixin:
|
||||
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||
prompt = f.read()
|
||||
if prompt_path.endswith('.json'):
|
||||
# replace any line endings with commas for \n \r \r\n
|
||||
prompt = prompt.replace('\r\n', ' ')
|
||||
prompt = prompt.replace('\n', ' ')
|
||||
prompt = prompt.replace('\r', ' ')
|
||||
|
||||
prompt = json.loads(prompt)
|
||||
if 'caption' in prompt:
|
||||
prompt = prompt['caption']
|
||||
@@ -277,6 +282,21 @@ class CaptionProcessingDTOMixin:
|
||||
# join back together
|
||||
caption = ', '.join(token_list)
|
||||
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
|
||||
|
||||
if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
|
||||
# add random triggers
|
||||
caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
|
||||
|
||||
if self.dataset_config.shuffle_tokens:
|
||||
# shuffle again
|
||||
token_list = caption.split(',')
|
||||
# trim whitespace
|
||||
token_list = [x.strip() for x in token_list]
|
||||
# remove empty strings
|
||||
token_list = [x for x in token_list if x]
|
||||
random.shuffle(token_list)
|
||||
caption = ', '.join(token_list)
|
||||
|
||||
return caption
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user