diff --git a/config/examples/extract.example.json b/config/examples/extract.example.json
index e08a639f..f42eaf46 100644
--- a/config/examples/extract.example.json
+++ b/config/examples/extract.example.json
@@ -6,10 +6,12 @@
     "extract_model": "/path/to/model/to/extract",
     "output_folder": "/path/to/output/folder",
     "is_v2": false,
+    "dtype": "fp16",
     "device": "cpu",
     "process": [
       {
         "filename":"[name]_64_32.safetensors",
+        "dtype": "fp16",
         "type": "locon",
         "mode": "fixed",
         "linear": 64,
diff --git a/jobs/ExtractJob.py b/jobs/ExtractJob.py
index 968b7960..fe1ccd8f 100644
--- a/jobs/ExtractJob.py
+++ b/jobs/ExtractJob.py
@@ -1,12 +1,11 @@
 from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
 from collections import OrderedDict
-from typing import List
-
-from jobs.process import BaseExtractProcess
 from jobs import BaseJob
+from toolkit.train_tools import get_torch_dtype
 
 
 process_dict = {
     'locon': 'ExtractLoconProcess',
+    'lora': 'ExtractLoraProcess',
 }
 
@@ -16,8 +15,16 @@ class ExtractJob(BaseJob):
         super().__init__(config)
         self.base_model_path = self.get_conf('base_model', required=True)
         self.base_model = None
+        self.base_text_encoder = None
+        self.base_vae = None
+        self.base_unet = None
         self.extract_model_path = self.get_conf('extract_model', required=True)
         self.extract_model = None
+        self.extract_text_encoder = None
+        self.extract_vae = None
+        self.extract_unet = None
+        self.dtype = self.get_conf('dtype', 'fp16')
+        self.torch_dtype = get_torch_dtype(self.dtype)
         self.output_folder = self.get_conf('output_folder', required=True)
         self.is_v2 = self.get_conf('is_v2', False)
         self.device = self.get_conf('device', 'cpu')
@@ -30,10 +37,17 @@ class ExtractJob(BaseJob):
 
         # load models
         print(f"Loading models for extraction")
         print(f" - Loading base model: {self.base_model_path}")
+        # (text_model, vae, unet)
         self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
+        self.base_text_encoder = self.base_model[0]
+        self.base_vae = self.base_model[1]
+        self.base_unet = self.base_model[2]
         print(f" - Loading extract model: {self.extract_model_path}")
         self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
+        self.extract_text_encoder = self.extract_model[0]
+        self.extract_vae = self.extract_model[1]
+        self.extract_unet = self.extract_model[2]
 
         print("")
         print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
diff --git a/jobs/process/BaseExtractProcess.py b/jobs/process/BaseExtractProcess.py
index 6140e0e8..541170b4 100644
--- a/jobs/process/BaseExtractProcess.py
+++ b/jobs/process/BaseExtractProcess.py
@@ -8,6 +8,8 @@
 from toolkit.metadata import get_meta_for_safetensors
 from typing import ForwardRef
 
+from toolkit.train_tools import get_torch_dtype
+
 
 class BaseExtractProcess(BaseProcess):
     process_id: int
@@ -26,6 +28,8 @@ class BaseExtractProcess(BaseProcess):
         self.process_id = process_id
         self.job = job
         self.config = config
+        self.dtype = self.get_conf('dtype', self.job.dtype)
+        self.torch_dtype = get_torch_dtype(self.dtype)
 
     def run(self):
         # here instead of init because child init needs to go first
@@ -70,6 +74,11 @@
 
         # save
         os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
+        for key in list(state_dict.keys()):
+            v = state_dict[key]
+            v = v.detach().clone().to("cpu").to(self.torch_dtype)
+            state_dict[key] = v
+
         # having issues with meta
         save_file(state_dict, self.output_path, save_meta)
 
diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py
index 5dab3337..e9c32360 100644
--- a/jobs/process/BaseProcess.py
+++ b/jobs/process/BaseProcess.py
@@ -21,13 +21,13 @@ class BaseProcess:
 
     def get_conf(self, key, default=None, required=False, as_type=None):
         if key in self.config:
             value = self.config[key]
-            if as_type is not None:
+            if as_type is not None and value is not None:
                 value = as_type(value)
             return value
         elif required:
             raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
         else:
-            if as_type is not None:
+            if as_type is not None and default is not None:
                 return as_type(default)
             return default
diff --git a/jobs/process/ExtractLoconProcess.py b/jobs/process/ExtractLoconProcess.py
index 9741bc5d..08e9f00c 100644
--- a/jobs/process/ExtractLoconProcess.py
+++ b/jobs/process/ExtractLoconProcess.py
@@ -37,8 +37,8 @@ class ExtractLoconProcess(BaseExtractProcess):
         # set modes
         if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
             raise ValueError(f"Unknown mode: {self.mode}")
-        self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type'])
-        self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type'])
+        self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type'])
+        self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type'])
 
     def run(self):
         super().run()
diff --git a/jobs/process/ExtractLoraProcess.py b/jobs/process/ExtractLoraProcess.py
new file mode 100644
index 00000000..af1cdde4
--- /dev/null
+++ b/jobs/process/ExtractLoraProcess.py
@@ -0,0 +1,52 @@
+from collections import OrderedDict
+from toolkit.lycoris_utils import extract_diff
+from .BaseExtractProcess import BaseExtractProcess
+
+mode_dict = {
+    'fixed': {
+        'dim': 4,
+        'type': int
+    }
+}
+
+CLAMP_QUANTILE = 0.99
+MIN_DIFF = 1e-6
+
+
+class ExtractLoraProcess(BaseExtractProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.mode = self.get_conf('mode', 'fixed')
+
+        # set modes
+        if self.mode not in ['fixed']:
+            raise ValueError(f"Unknown mode: {self.mode}")
+        self.dim = self.get_conf('dim', mode_dict[self.mode]['dim'], as_type=mode_dict[self.mode]['type'])
+        self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
+        self.sparsity = self.get_conf('sparsity', 0.98)
+
+    def run(self):
+        super().run()
+        print(f"Running process: {self.mode}, dim: {self.dim}")
+
+        state_dict, extract_diff_meta = extract_diff(
+            self.job.base_model,
+            self.job.extract_model,
+            self.mode,
+            self.dim,
+            0,
+            self.job.device,
+            self.use_sparse_bias,
+            self.sparsity,
+            small_conv=False,
+            linear_only=True,
+        )
+
+        self.add_meta(extract_diff_meta)
+        self.save(state_dict)
+
+    def get_output_path(self, prefix=None, suffix=None):
+        if suffix is None:
+            suffix = f"_{self.dim}"
+        return super().get_output_path(prefix, suffix)
diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py
index 8879ed3a..1e307d53 100644
--- a/jobs/process/__init__.py
+++ b/jobs/process/__init__.py
@@ -1,5 +1,6 @@
 from .BaseExtractProcess import BaseExtractProcess
 from .ExtractLoconProcess import ExtractLoconProcess
+from .ExtractLoraProcess import ExtractLoraProcess
 from .BaseProcess import BaseProcess
 from .BaseTrainProcess import BaseTrainProcess
 from .TrainFineTuneProcess import TrainFineTuneProcess
diff --git a/toolkit/lycoris_utils.py b/toolkit/lycoris_utils.py
index c85db570..493be801 100644
--- a/toolkit/lycoris_utils.py
+++ b/toolkit/lycoris_utils.py
@@ -67,6 +67,9 @@ def extract_conv(
         return (extract_weight_A, extract_weight_B, diff), 'low rank'
 
 
+extra_weights = ['lora_unet_conv_in.alpha', 'lora_unet_conv_in.lora_down.weight', 'lora_unet_conv_in.lora_mid.weight', 'lora_unet_conv_in.lora_up.weight', 'lora_unet_conv_out.alpha', 'lora_unet_conv_out.lora_down.weight', 'lora_unet_conv_out.lora_mid.weight', 'lora_unet_conv_out.lora_up.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.alpha', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_time_embedding_linear_1.alpha', 'lora_unet_time_embedding_linear_1.lora_down.weight', 'lora_unet_time_embedding_linear_1.lora_up.weight', 'lora_unet_time_embedding_linear_2.alpha', 'lora_unet_time_embedding_linear_2.lora_down.weight', 'lora_unet_time_embedding_linear_2.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_down.weight',
'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 
'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_up.weight']
+
+
 def extract_linear(
     weight: Union[torch.Tensor, nn.Parameter],
     mode='fixed',
@@ -120,7 +123,8 @@ def extract_diff(
     extract_device='cpu',
     use_bias=False,
     sparsity=0.98,
-    small_conv=True
+    small_conv=True,
+    linear_only=False,
 ):
     meta = OrderedDict()
 
@@ -137,6 +141,13 @@ def extract_diff(
         "time_embedding.linear_1",
         "time_embedding.linear_2",
     ]
+    if linear_only:
+        UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+        UNET_TARGET_REPLACE_NAME = [
+            "conv_in",
+            "conv_out",
+        ]
+
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = 'lora_unet'
     LORA_PREFIX_TEXT_ENCODER = 'lora_te'
@@ -186,6 +197,8 @@ def extract_diff(
            elif layer == 'Conv2d':
                is_linear = (child_module.weight.shape[2] == 1
                             and child_module.weight.shape[3] == 1)
+                if not is_linear and linear_only:
+                    continue
                weight, decompose_mode = extract_conv(
                    (child_module.weight - weights[child_name]),
                    mode,
@@ -254,6 +267,8 @@
                root_weight.shape[2] == 1
                and root_weight.shape[3] == 1
            )
+            if not is_linear and linear_only:
+                continue
            weight, decompose_mode = extract_conv(
                (root_weight - weights),
                mode,
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index bd4b4075..4d6e38e4 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -27,6 +27,16 @@ SCHEDULER_TIMESTEPS = 1000
 SCHEDLER_SCHEDULE = "scaled_linear"
 
 
+def get_torch_dtype(dtype_str):
+    if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32":
+        return torch.float
+    if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16":
+        return torch.float16
+    if dtype_str == "bf16" or dtype_str == "bfloat16":
+        return torch.bfloat16
+    return None
+
+
 def replace_filewords_prompt(prompt, args: argparse.Namespace):
     # if name_replace attr in args (may not be)
     if hasattr(args, "name_replace") and args.name_replace is not None:
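
Note: for reference, a config "process" entry exercising the new "lora" type might look like the sketch below. This example is not part of the diff: the keys and values mirror the defaults read by ExtractLoraProcess (mode "fixed", dim 4, use_sparse_bias false, sparsity 0.98) and the per-process "dtype" handled by BaseExtractProcess, while the filename pattern is illustrative only.

    {
      "filename": "[name]_lora_4.safetensors",
      "dtype": "fp16",
      "type": "lora",
      "mode": "fixed",
      "dim": 4,
      "use_sparse_bias": false,
      "sparsity": 0.98
    }

ExtractLoraProcess always calls extract_diff with small_conv=False and linear_only=True, so only linear layers (and 1x1 convolutions) are decomposed; other conv layers are skipped by the new linear_only checks in lycoris_utils.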