Adjusted some names and allow to ignore text encoder on extract

This commit is contained in:
Jaret Burkett
2023-07-16 12:52:39 -06:00
parent a61914ddd9
commit 9df7af1720
5 changed files with 39 additions and 21 deletions

View File

@@ -14,15 +14,17 @@ class ExtractJob(BaseJob):
def __init__(self, config: OrderedDict):
super().__init__(config)
self.base_model_path = self.get_conf('base_model', required=True)
self.base_model = None
self.base_text_encoder = None
self.base_vae = None
self.base_unet = None
self.model_base = None
self.model_base_text_encoder = None
self.model_base_vae = None
self.model_base_unet = None
self.extract_model_path = self.get_conf('extract_model', required=True)
self.extract_model = None
self.extract_text_encoder = None
self.extract_vae = None
self.extract_unet = None
self.model_extract = None
self.model_extract_text_encoder = None
self.model_extract_vae = None
self.model_extract_unet = None
self.extract_unet = self.get_conf('extract_unet', True)
self.extract_text_encoder = self.get_conf('extract_text_encoder', True)
self.dtype = self.get_conf('dtype', 'fp16')
self.torch_dtype = get_torch_dtype(self.dtype)
self.output_folder = self.get_conf('output_folder', required=True)
@@ -38,16 +40,16 @@ class ExtractJob(BaseJob):
print(f"Loading models for extraction")
print(f" - Loading base model: {self.base_model_path}")
# (text_model, vae, unet)
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
self.base_text_encoder = self.base_model[0]
self.base_vae = self.base_model[1]
self.base_unet = self.base_model[2]
self.model_base = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
self.model_base_text_encoder = self.model_base[0]
self.model_base_vae = self.model_base[1]
self.model_base_unet = self.model_base[2]
print(f" - Loading extract model: {self.extract_model_path}")
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
self.extract_text_encoder = self.extract_model[0]
self.extract_vae = self.extract_model[1]
self.extract_unet = self.extract_model[2]
self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
self.model_extract_text_encoder = self.model_extract[0]
self.model_extract_vae = self.model_extract[1]
self.model_extract_unet = self.model_extract[2]
print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")