mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added support for traditional LoRA extract using LoCon script
This commit is contained in:
@@ -6,10 +6,12 @@
|
||||
"extract_model": "/path/to/model/to/extract",
|
||||
"output_folder": "/path/to/output/folder",
|
||||
"is_v2": false,
|
||||
"dtype": "fp16",
|
||||
"device": "cpu",
|
||||
"process": [
|
||||
{
|
||||
"filename":"[name]_64_32.safetensors",
|
||||
"dtype": "fp16",
|
||||
"type": "locon",
|
||||
"mode": "fixed",
|
||||
"linear": 64,
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
from jobs.process import BaseExtractProcess
|
||||
from jobs import BaseJob
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
# Maps the "type" value of a process entry in the job config to the name of
# the process class that handles it.
process_dict = {
    'locon': 'ExtractLoconProcess',
    'lora': 'ExtractLoraProcess',
}
|
||||
|
||||
|
||||
@@ -16,8 +15,16 @@ class ExtractJob(BaseJob):
|
||||
super().__init__(config)
|
||||
self.base_model_path = self.get_conf('base_model', required=True)
|
||||
self.base_model = None
|
||||
self.base_text_encoder = None
|
||||
self.base_vae = None
|
||||
self.base_unet = None
|
||||
self.extract_model_path = self.get_conf('extract_model', required=True)
|
||||
self.extract_model = None
|
||||
self.extract_text_encoder = None
|
||||
self.extract_vae = None
|
||||
self.extract_unet = None
|
||||
self.dtype = self.get_conf('dtype', 'fp16')
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
self.output_folder = self.get_conf('output_folder', required=True)
|
||||
self.is_v2 = self.get_conf('is_v2', False)
|
||||
self.device = self.get_conf('device', 'cpu')
|
||||
@@ -30,10 +37,17 @@ class ExtractJob(BaseJob):
|
||||
# load models
|
||||
print(f"Loading models for extraction")
|
||||
print(f" - Loading base model: {self.base_model_path}")
|
||||
# (text_model, vae, unet)
|
||||
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
|
||||
self.base_text_encoder = self.base_model[0]
|
||||
self.base_vae = self.base_model[1]
|
||||
self.base_unet = self.base_model[2]
|
||||
|
||||
print(f" - Loading extract model: {self.extract_model_path}")
|
||||
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
|
||||
self.extract_text_encoder = self.extract_model[0]
|
||||
self.extract_vae = self.extract_model[1]
|
||||
self.extract_unet = self.extract_model[2]
|
||||
|
||||
print("")
|
||||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
|
||||
|
||||
@@ -8,6 +8,8 @@ from toolkit.metadata import get_meta_for_safetensors
|
||||
|
||||
from typing import ForwardRef
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
|
||||
class BaseExtractProcess(BaseProcess):
|
||||
process_id: int
|
||||
@@ -26,6 +28,8 @@ class BaseExtractProcess(BaseProcess):
|
||||
self.process_id = process_id
|
||||
self.job = job
|
||||
self.config = config
|
||||
self.dtype = self.get_conf('dtype', self.job.dtype)
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
|
||||
def run(self):
|
||||
# here instead of init because child init needs to go first
|
||||
@@ -70,6 +74,11 @@ class BaseExtractProcess(BaseProcess):
|
||||
# save
|
||||
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(self.torch_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
# having issues with meta
|
||||
save_file(state_dict, self.output_path, save_meta)
|
||||
|
||||
|
||||
@@ -21,13 +21,13 @@ class BaseProcess:
|
||||
def get_conf(self, key, default=None, required=False, as_type=None):
    """Return the value stored under ``key`` in ``self.config``.

    Args:
        key: Name of the config entry to read.
        default: Value returned when ``key`` is absent and not required.
        required: When true, a missing ``key`` raises instead of falling
            back to ``default``.
        as_type: Optional callable used to convert the returned value.

    Returns:
        The configured value (or ``default``), passed through ``as_type``
        when a converter is given and the value is not ``None``.

    Raises:
        ValueError: If ``key`` is missing and ``required`` is true.
    """
    if key in self.config:
        value = self.config[key]
        # Only convert real values: a key explicitly set to None stays None
        # instead of being fed to the converter.
        if as_type is not None and value is not None:
            value = as_type(value)
        return value
    elif required:
        raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
    else:
        # Same guard for the fallback, so a None default is returned untouched.
        if as_type is not None and default is not None:
            return as_type(default)
        return default
|
||||
|
||||
|
||||
@@ -37,8 +37,8 @@ class ExtractLoconProcess(BaseExtractProcess):
|
||||
# set modes
|
||||
if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
|
||||
raise ValueError(f"Unknown mode: {self.mode}")
|
||||
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type'])
|
||||
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type'])
|
||||
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type'])
|
||||
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type'])
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
52
jobs/process/ExtractLoraProcess.py
Normal file
52
jobs/process/ExtractLoraProcess.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from collections import OrderedDict
|
||||
from toolkit.lycoris_utils import extract_diff
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
|
||||
# Per-mode defaults for the LoRA extraction process: 'dim' is the fallback
# used when the process config has no "dim" entry, and 'type' is the callable
# the configured value is converted with.
mode_dict = {
    'fixed': {
        'dim': 4,
        'type': int
    }
}

# NOTE(review): neither constant is referenced in the visible part of this
# module; confirm whether they are still needed.
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6
|
||||
|
||||
|
||||
class ExtractLoraProcess(BaseExtractProcess):
    """Extract process that runs ``extract_diff`` in linear-only mode.

    The process reads its settings from its own config entry, runs the
    extraction against the models held by the parent job, and saves the
    resulting state dict through the base class.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        """Read and validate the process settings.

        Raises:
            ValueError: If the configured ``mode`` is not ``'fixed'``.
        """
        super().__init__(process_id, job, config)
        self.mode = self.get_conf('mode', 'fixed')

        # set modes
        if self.mode not in ['fixed']:
            raise ValueError(f"Unknown mode: {self.mode}")
        # Default and converter for "dim" both come from the selected mode.
        self.dim = self.get_conf('dim', mode_dict[self.mode]['dim'], as_type=mode_dict[self.mode]['type'])
        self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
        self.sparsity = self.get_conf('sparsity', 0.98)

    def run(self):
        """Run the extraction, attach its metadata, and save the result."""
        super().run()
        print(f"Running process: {self.mode}, dim: {self.dim}")

        # NOTE(review): the literal 0 is presumably the conv parameter, which
        # would be unused because linear_only=True; confirm against the
        # extract_diff signature.
        state_dict, extract_diff_meta = extract_diff(
            self.job.base_model,
            self.job.extract_model,
            self.mode,
            self.dim,
            0,
            self.job.device,
            self.use_sparse_bias,
            self.sparsity,
            small_conv=False,
            linear_only=True,
        )

        self.add_meta(extract_diff_meta)
        self.save(state_dict)

    def get_output_path(self, prefix=None, suffix=None):
        """Return the output path, using ``_<dim>`` when no suffix is given."""
        if suffix is None:
            suffix = f"_{self.dim}"
        return super().get_output_path(prefix, suffix)
|
||||
@@ -1,5 +1,6 @@
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
from .ExtractLoconProcess import ExtractLoconProcess
|
||||
from .ExtractLoraProcess import ExtractLoraProcess
|
||||
from .BaseProcess import BaseProcess
|
||||
from .BaseTrainProcess import BaseTrainProcess
|
||||
from .TrainFineTuneProcess import TrainFineTuneProcess
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -27,6 +27,16 @@ SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
def get_torch_dtype(dtype_str):
    """Translate a dtype name from config into a ``torch.dtype``.

    Args:
        dtype_str: One of the accepted spellings, e.g. ``"fp16"``,
            ``"float32"`` or ``"bf16"``.

    Returns:
        ``torch.float``, ``torch.float16`` or ``torch.bfloat16`` for a
        recognized name; ``None`` for anything else, so callers must be
        prepared to handle ``None``.
    """
    # Each row pairs the accepted spellings with the dtype they select.
    alias_table = (
        (("float", "fp32", "single", "float32"), torch.float),
        (("fp16", "half", "float16"), torch.float16),
        (("bf16", "bfloat16"), torch.bfloat16),
    )
    for aliases, torch_dtype in alias_table:
        if dtype_str in aliases:
            return torch_dtype
    return None
|
||||
|
||||
|
||||
def replace_filewords_prompt(prompt, args: argparse.Namespace):
|
||||
# if name_replace attr in args (may not be)
|
||||
if hasattr(args, "name_replace") and args.name_replace is not None:
|
||||
|
||||
Reference in New Issue
Block a user