diff --git a/jobs/ExtractJob.py b/jobs/ExtractJob.py index fe1ccd8f..d710d412 100644 --- a/jobs/ExtractJob.py +++ b/jobs/ExtractJob.py @@ -14,15 +14,17 @@ class ExtractJob(BaseJob): def __init__(self, config: OrderedDict): super().__init__(config) self.base_model_path = self.get_conf('base_model', required=True) - self.base_model = None - self.base_text_encoder = None - self.base_vae = None - self.base_unet = None + self.model_base = None + self.model_base_text_encoder = None + self.model_base_vae = None + self.model_base_unet = None self.extract_model_path = self.get_conf('extract_model', required=True) - self.extract_model = None - self.extract_text_encoder = None - self.extract_vae = None - self.extract_unet = None + self.model_extract = None + self.model_extract_text_encoder = None + self.model_extract_vae = None + self.model_extract_unet = None + self.extract_unet = self.get_conf('extract_unet', True) + self.extract_text_encoder = self.get_conf('extract_text_encoder', True) self.dtype = self.get_conf('dtype', 'fp16') self.torch_dtype = get_torch_dtype(self.dtype) self.output_folder = self.get_conf('output_folder', required=True) @@ -38,16 +40,16 @@ class ExtractJob(BaseJob): print(f"Loading models for extraction") print(f" - Loading base model: {self.base_model_path}") # (text_model, vae, unet) - self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) - self.base_text_encoder = self.base_model[0] - self.base_vae = self.base_model[1] - self.base_unet = self.base_model[2] + self.model_base = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) + self.model_base_text_encoder = self.model_base[0] + self.model_base_vae = self.model_base[1] + self.model_base_unet = self.model_base[2] print(f" - Loading extract model: {self.extract_model_path}") - self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path) - self.extract_text_encoder = self.extract_model[0] - self.extract_vae = self.extract_model[1] - self.extract_unet = self.extract_model[2] + self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path) + self.model_extract_text_encoder = self.model_extract[0] + self.model_extract_vae = self.model_extract[1] + self.model_extract_unet = self.model_extract[2] print("") print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") diff --git a/jobs/process/BaseExtractProcess.py b/jobs/process/BaseExtractProcess.py index 541170b4..009bb7ba 100644 --- a/jobs/process/BaseExtractProcess.py +++ b/jobs/process/BaseExtractProcess.py @@ -30,6 +30,8 @@ class BaseExtractProcess(BaseProcess): self.config = config self.dtype = self.get_conf('dtype', self.job.dtype) self.torch_dtype = get_torch_dtype(self.dtype) + self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet) + self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder) def run(self): # here instead of init because child init needs to go first diff --git a/jobs/process/ExtractLoconProcess.py b/jobs/process/ExtractLoconProcess.py index e0e16506..b5dac5ed 100644 --- a/jobs/process/ExtractLoconProcess.py +++ b/jobs/process/ExtractLoconProcess.py @@ -45,15 +45,17 @@ class ExtractLoconProcess(BaseExtractProcess): print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}") state_dict, extract_diff_meta = extract_diff( - self.job.base_model, - self.job.extract_model, + self.job.model_base, + self.job.model_extract, self.mode, self.linear_param, self.conv_param, self.job.device, self.use_sparse_bias, self.sparsity, - not self.disable_cp + not self.disable_cp, + extract_unet=self.extract_unet, + extract_text_encoder=self.extract_text_encoder ) self.add_meta(extract_diff_meta) diff --git a/jobs/process/ExtractLoraProcess.py b/jobs/process/ExtractLoraProcess.py index 70ea7e16..343a6bd4 100644 --- a/jobs/process/ExtractLoraProcess.py +++ b/jobs/process/ExtractLoraProcess.py @@ -44,8 +44,8 @@ class ExtractLoraProcess(BaseExtractProcess): print(f"Running process: {self.mode}, dim: {self.dim}") state_dict, extract_diff_meta = extract_diff( - self.job.base_model, - self.job.extract_model, + self.job.model_base, + self.job.model_extract, self.mode, self.dim, 0, @@ -54,6 +54,8 @@ class ExtractLoraProcess(BaseExtractProcess): self.sparsity, small_conv=False, linear_only=True, + extract_unet=self.extract_unet, + extract_text_encoder=self.extract_text_encoder ) self.add_meta(extract_diff_meta) diff --git a/toolkit/lycoris_utils.py b/toolkit/lycoris_utils.py index 493be801..dad5aff8 100644 --- a/toolkit/lycoris_utils.py +++ b/toolkit/lycoris_utils.py @@ -125,6 +125,8 @@ def extract_diff( sparsity=0.98, small_conv=True, linear_only=False, + extract_unet=True, + extract_text_encoder=True, ): meta = OrderedDict() @@ -148,7 +150,15 @@ def extract_diff( "conv_out", ] + if not extract_unet: + UNET_TARGET_REPLACE_MODULE = [] + UNET_TARGET_REPLACE_NAME = [] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + + if not extract_text_encoder: + TEXT_ENCODER_TARGET_REPLACE_MODULE = [] + LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te'