Adjusted some names and allow ignoring the text encoder on extract

This commit is contained in:
Jaret Burkett
2023-07-16 12:52:39 -06:00
parent a61914ddd9
commit 9df7af1720
5 changed files with 39 additions and 21 deletions

View File

@@ -14,15 +14,17 @@ class ExtractJob(BaseJob):
def __init__(self, config: OrderedDict):
super().__init__(config)
self.base_model_path = self.get_conf('base_model', required=True)
self.base_model = None
self.base_text_encoder = None
self.base_vae = None
self.base_unet = None
self.model_base = None
self.model_base_text_encoder = None
self.model_base_vae = None
self.model_base_unet = None
self.extract_model_path = self.get_conf('extract_model', required=True)
self.extract_model = None
self.extract_text_encoder = None
self.extract_vae = None
self.extract_unet = None
self.model_extract = None
self.model_extract_text_encoder = None
self.model_extract_vae = None
self.model_extract_unet = None
self.extract_unet = self.get_conf('extract_unet', True)
self.extract_text_encoder = self.get_conf('extract_text_encoder', True)
self.dtype = self.get_conf('dtype', 'fp16')
self.torch_dtype = get_torch_dtype(self.dtype)
self.output_folder = self.get_conf('output_folder', required=True)
@@ -38,16 +40,16 @@ class ExtractJob(BaseJob):
print(f"Loading models for extraction")
print(f" - Loading base model: {self.base_model_path}")
# (text_model, vae, unet)
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
self.base_text_encoder = self.base_model[0]
self.base_vae = self.base_model[1]
self.base_unet = self.base_model[2]
self.model_base = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
self.model_base_text_encoder = self.model_base[0]
self.model_base_vae = self.model_base[1]
self.model_base_unet = self.model_base[2]
print(f" - Loading extract model: {self.extract_model_path}")
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
self.extract_text_encoder = self.extract_model[0]
self.extract_vae = self.extract_model[1]
self.extract_unet = self.extract_model[2]
self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
self.model_extract_text_encoder = self.model_extract[0]
self.model_extract_vae = self.model_extract[1]
self.model_extract_unet = self.model_extract[2]
print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

View File

@@ -30,6 +30,8 @@ class BaseExtractProcess(BaseProcess):
self.config = config
self.dtype = self.get_conf('dtype', self.job.dtype)
self.torch_dtype = get_torch_dtype(self.dtype)
self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet)
self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder)
def run(self):
# here instead of init because child init needs to go first

View File

@@ -45,15 +45,17 @@ class ExtractLoconProcess(BaseExtractProcess):
print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")
state_dict, extract_diff_meta = extract_diff(
self.job.base_model,
self.job.extract_model,
self.job.model_base,
self.job.model_extract,
self.mode,
self.linear_param,
self.conv_param,
self.job.device,
self.use_sparse_bias,
self.sparsity,
not self.disable_cp
not self.disable_cp,
extract_unet=self.extract_unet,
extract_text_encoder=self.extract_text_encoder
)
self.add_meta(extract_diff_meta)

View File

@@ -44,8 +44,8 @@ class ExtractLoraProcess(BaseExtractProcess):
print(f"Running process: {self.mode}, dim: {self.dim}")
state_dict, extract_diff_meta = extract_diff(
self.job.base_model,
self.job.extract_model,
self.job.model_base,
self.job.model_extract,
self.mode,
self.dim,
0,
@@ -54,6 +54,8 @@ class ExtractLoraProcess(BaseExtractProcess):
self.sparsity,
small_conv=False,
linear_only=True,
extract_unet=self.extract_unet,
extract_text_encoder=self.extract_text_encoder
)
self.add_meta(extract_diff_meta)

View File

@@ -125,6 +125,8 @@ def extract_diff(
sparsity=0.98,
small_conv=True,
linear_only=False,
extract_unet=True,
extract_text_encoder=True,
):
meta = OrderedDict()
@@ -148,7 +150,15 @@ def extract_diff(
"conv_out",
]
if not extract_unet:
UNET_TARGET_REPLACE_MODULE = []
UNET_TARGET_REPLACE_NAME = []
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
if not extract_text_encoder:
TEXT_ENCODER_TARGET_REPLACE_MODULE = []
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'