Moved some of the job config into base process so it will be easier to extend extensions

This commit is contained in:
Jaret Burkett
2023-08-10 12:14:05 -06:00
parent fbc8a87a05
commit df48f0a843
6 changed files with 43 additions and 30 deletions

View File

@@ -52,18 +52,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
pass
def hook_before_train_loop(self):
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
# read line by line from file
if self.slider_config.prompt_file:
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
self.print(f"Found {len(self.prompt_txt_list)} prompts.")
if not self.slider_config.prompt_tensors:
print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps