various bug fixes. Created an contextual alpha mask module to calculate alpha mask

This commit is contained in:
Jaret Burkett
2024-01-18 16:34:27 -07:00
parent 86c70a2a1f
commit f17ad8d794
7 changed files with 93 additions and 28 deletions

View File

@@ -186,6 +186,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
sample_config = self.first_sample_config if is_first else self.sample_config
start_seed = sample_config.seed
current_seed = start_seed
test_image_paths = []
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
test_image_path_list = self.adapter_config.test_img_path.split(',')
test_image_path_list = [p.strip() for p in test_image_path_list]
test_image_path_list = [p for p in test_image_path_list if p != '']
# divide up images so they are evenly distributed across prompts
for i in range(len(sample_config.prompts)):
test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])
for i in range(len(sample_config.prompts)):
if sample_config.walk_seed:
current_seed = start_seed + i
@@ -219,7 +229,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
extra_args = {}
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
extra_args['adapter_image_path'] = self.adapter_config.test_img_path
extra_args['adapter_image_path'] = test_image_paths[i]
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt