Added support for traditional LoRa extract using LoCon script

This commit is contained in:
Jaret Burkett
2023-07-12 19:51:40 -06:00
parent 57f14e5ef2
commit 8d6edae9fd
9 changed files with 111 additions and 8 deletions

View File

@@ -6,10 +6,12 @@
"extract_model": "/path/to/model/to/extract",
"output_folder": "/path/to/output/folder",
"is_v2": false,
"dtype": "fp16",
"device": "cpu",
"process": [
{
"filename":"[name]_64_32.safetensors",
"dtype": "fp16",
"type": "locon",
"mode": "fixed",
"linear": 64,

View File

@@ -1,12 +1,11 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess
from jobs import BaseJob
from toolkit.train_tools import get_torch_dtype
process_dict = {
'locon': 'ExtractLoconProcess',
'lora': 'ExtractLoraProcess',
}
@@ -16,8 +15,16 @@ class ExtractJob(BaseJob):
super().__init__(config)
self.base_model_path = self.get_conf('base_model', required=True)
self.base_model = None
self.base_text_encoder = None
self.base_vae = None
self.base_unet = None
self.extract_model_path = self.get_conf('extract_model', required=True)
self.extract_model = None
self.extract_text_encoder = None
self.extract_vae = None
self.extract_unet = None
self.dtype = self.get_conf('dtype', 'fp16')
self.torch_dtype = get_torch_dtype(self.dtype)
self.output_folder = self.get_conf('output_folder', required=True)
self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu')
@@ -30,10 +37,17 @@ class ExtractJob(BaseJob):
# load models
print(f"Loading models for extraction")
print(f" - Loading base model: {self.base_model_path}")
# (text_model, vae, unet)
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
self.base_text_encoder = self.base_model[0]
self.base_vae = self.base_model[1]
self.base_unet = self.base_model[2]
print(f" - Loading extract model: {self.extract_model_path}")
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
self.extract_text_encoder = self.extract_model[0]
self.extract_vae = self.extract_model[1]
self.extract_unet = self.extract_model[2]
print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

View File

@@ -8,6 +8,8 @@ from toolkit.metadata import get_meta_for_safetensors
from typing import ForwardRef
from toolkit.train_tools import get_torch_dtype
class BaseExtractProcess(BaseProcess):
process_id: int
@@ -26,6 +28,8 @@ class BaseExtractProcess(BaseProcess):
self.process_id = process_id
self.job = job
self.config = config
self.dtype = self.get_conf('dtype', self.job.dtype)
self.torch_dtype = get_torch_dtype(self.dtype)
def run(self):
# here instead of init because child init needs to go first
@@ -70,6 +74,11 @@ class BaseExtractProcess(BaseProcess):
# save
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(self.torch_dtype)
state_dict[key] = v
# having issues with meta
save_file(state_dict, self.output_path, save_meta)

View File

@@ -21,13 +21,13 @@ class BaseProcess:
def get_conf(self, key, default=None, required=False, as_type=None):
if key in self.config:
value = self.config[key]
if as_type is not None:
if as_type is not None and value is not None:
value = as_type(value)
return value
elif required:
raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
else:
if as_type is not None:
if as_type is not None and default is not None:
return as_type(default)
return default

View File

@@ -37,8 +37,8 @@ class ExtractLoconProcess(BaseExtractProcess):
# set modes
if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
raise ValueError(f"Unknown mode: {self.mode}")
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type'])
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type'])
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type'])
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type'])
def run(self):
super().run()

View File

@@ -0,0 +1,52 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
mode_dict = {
'fixed': {
'dim': 4,
'type': int
}
}
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6
class ExtractLoraProcess(BaseExtractProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed')
# set modes
if self.mode not in ['fixed']:
raise ValueError(f"Unknown mode: {self.mode}")
self.dim = self.get_conf('dim', mode_dict[self.mode]['dim'], as_type=mode_dict[self.mode]['type'])
self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
self.sparsity = self.get_conf('sparsity', 0.98)
def run(self):
super().run()
print(f"Running process: {self.mode}, dim: {self.dim}")
state_dict, extract_diff_meta = extract_diff(
self.job.base_model,
self.job.extract_model,
self.mode,
self.dim,
0,
self.job.device,
self.use_sparse_bias,
self.sparsity,
small_conv=False,
linear_only=True,
)
self.add_meta(extract_diff_meta)
self.save(state_dict)
def get_output_path(self, prefix=None, suffix=None):
if suffix is None:
suffix = f"_{self.dim}"
return super().get_output_path(prefix, suffix)

View File

@@ -1,5 +1,6 @@
from .BaseExtractProcess import BaseExtractProcess
from .ExtractLoconProcess import ExtractLoconProcess
from .ExtractLoraProcess import ExtractLoraProcess
from .BaseProcess import BaseProcess
from .BaseTrainProcess import BaseTrainProcess
from .TrainFineTuneProcess import TrainFineTuneProcess

File diff suppressed because one or more lines are too long

View File

@@ -27,6 +27,16 @@ SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
def get_torch_dtype(dtype_str):
if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32":
return torch.float
if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16":
return torch.float16
if dtype_str == "bf16" or dtype_str == "bfloat16":
return torch.bfloat16
return None
def replace_filewords_prompt(prompt, args: argparse.Namespace):
# if name_replace attr in args (may not be)
if hasattr(args, "name_replace") and args.name_replace is not None: