mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-26 23:33:58 +00:00
Adjusted some names and allowed the text encoder to be ignored on extract
This commit is contained in:
@@ -14,15 +14,17 @@ class ExtractJob(BaseJob):
|
||||
def __init__(self, config: OrderedDict):
|
||||
super().__init__(config)
|
||||
self.base_model_path = self.get_conf('base_model', required=True)
|
||||
self.base_model = None
|
||||
self.base_text_encoder = None
|
||||
self.base_vae = None
|
||||
self.base_unet = None
|
||||
self.model_base = None
|
||||
self.model_base_text_encoder = None
|
||||
self.model_base_vae = None
|
||||
self.model_base_unet = None
|
||||
self.extract_model_path = self.get_conf('extract_model', required=True)
|
||||
self.extract_model = None
|
||||
self.extract_text_encoder = None
|
||||
self.extract_vae = None
|
||||
self.extract_unet = None
|
||||
self.model_extract = None
|
||||
self.model_extract_text_encoder = None
|
||||
self.model_extract_vae = None
|
||||
self.model_extract_unet = None
|
||||
self.extract_unet = self.get_conf('extract_unet', True)
|
||||
self.extract_text_encoder = self.get_conf('extract_text_encoder', True)
|
||||
self.dtype = self.get_conf('dtype', 'fp16')
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
self.output_folder = self.get_conf('output_folder', required=True)
|
||||
@@ -38,16 +40,16 @@ class ExtractJob(BaseJob):
|
||||
print(f"Loading models for extraction")
|
||||
print(f" - Loading base model: {self.base_model_path}")
|
||||
# (text_model, vae, unet)
|
||||
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
|
||||
self.base_text_encoder = self.base_model[0]
|
||||
self.base_vae = self.base_model[1]
|
||||
self.base_unet = self.base_model[2]
|
||||
self.model_base = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
|
||||
self.model_base_text_encoder = self.model_base[0]
|
||||
self.model_base_vae = self.model_base[1]
|
||||
self.model_base_unet = self.model_base[2]
|
||||
|
||||
print(f" - Loading extract model: {self.extract_model_path}")
|
||||
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
|
||||
self.extract_text_encoder = self.extract_model[0]
|
||||
self.extract_vae = self.extract_model[1]
|
||||
self.extract_unet = self.extract_model[2]
|
||||
self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
|
||||
self.model_extract_text_encoder = self.model_extract[0]
|
||||
self.model_extract_vae = self.model_extract[1]
|
||||
self.model_extract_unet = self.model_extract[2]
|
||||
|
||||
print("")
|
||||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
|
||||
|
||||
@@ -30,6 +30,8 @@ class BaseExtractProcess(BaseProcess):
|
||||
self.config = config
|
||||
self.dtype = self.get_conf('dtype', self.job.dtype)
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet)
|
||||
self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder)
|
||||
|
||||
def run(self):
|
||||
# here instead of init because child init needs to go first
|
||||
|
||||
@@ -45,15 +45,17 @@ class ExtractLoconProcess(BaseExtractProcess):
|
||||
print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")
|
||||
|
||||
state_dict, extract_diff_meta = extract_diff(
|
||||
self.job.base_model,
|
||||
self.job.extract_model,
|
||||
self.job.model_base,
|
||||
self.job.model_extract,
|
||||
self.mode,
|
||||
self.linear_param,
|
||||
self.conv_param,
|
||||
self.job.device,
|
||||
self.use_sparse_bias,
|
||||
self.sparsity,
|
||||
not self.disable_cp
|
||||
not self.disable_cp,
|
||||
extract_unet=self.extract_unet,
|
||||
extract_text_encoder=self.extract_text_encoder
|
||||
)
|
||||
|
||||
self.add_meta(extract_diff_meta)
|
||||
|
||||
@@ -44,8 +44,8 @@ class ExtractLoraProcess(BaseExtractProcess):
|
||||
print(f"Running process: {self.mode}, dim: {self.dim}")
|
||||
|
||||
state_dict, extract_diff_meta = extract_diff(
|
||||
self.job.base_model,
|
||||
self.job.extract_model,
|
||||
self.job.model_base,
|
||||
self.job.model_extract,
|
||||
self.mode,
|
||||
self.dim,
|
||||
0,
|
||||
@@ -54,6 +54,8 @@ class ExtractLoraProcess(BaseExtractProcess):
|
||||
self.sparsity,
|
||||
small_conv=False,
|
||||
linear_only=True,
|
||||
extract_unet=self.extract_unet,
|
||||
extract_text_encoder=self.extract_text_encoder
|
||||
)
|
||||
|
||||
self.add_meta(extract_diff_meta)
|
||||
|
||||
@@ -125,6 +125,8 @@ def extract_diff(
|
||||
sparsity=0.98,
|
||||
small_conv=True,
|
||||
linear_only=False,
|
||||
extract_unet=True,
|
||||
extract_text_encoder=True,
|
||||
):
|
||||
meta = OrderedDict()
|
||||
|
||||
@@ -148,7 +150,15 @@ def extract_diff(
|
||||
"conv_out",
|
||||
]
|
||||
|
||||
if not extract_unet:
|
||||
UNET_TARGET_REPLACE_MODULE = []
|
||||
UNET_TARGET_REPLACE_NAME = []
|
||||
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
|
||||
if not extract_text_encoder:
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = []
|
||||
|
||||
LORA_PREFIX_UNET = 'lora_unet'
|
||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user