diff --git a/jobs/MergeJob.py b/jobs/MergeJob.py
new file mode 100644
index 00000000..b9e3b87b
--- /dev/null
+++ b/jobs/MergeJob.py
@@ -0,0 +1,32 @@
+from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
+from collections import OrderedDict
+from jobs import BaseJob
+from toolkit.train_tools import get_torch_dtype
+
+# registry passed to load_processes(); currently empty
+process_dict = {
+}
+
+
+class MergeJob(BaseJob):
+    """Job that runs every process loaded from its config, in order."""
+
+    def __init__(self, config: OrderedDict):
+        super().__init__(config)
+        self.dtype = self.get_conf('dtype', 'fp16')
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.is_v2 = self.get_conf('is_v2', False)
+        self.device = self.get_conf('device', 'cpu')
+
+        # loads the processes from the config
+        self.load_processes(process_dict)
+
+    def run(self):
+        """Call the parent ``run`` and then run each loaded process in turn."""
+        super().run()
+
+        print("")
+        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+        for process in self.process:
+            process.run()
diff --git a/jobs/__init__.py b/jobs/__init__.py
index 9c2472f6..688ccfc1 100644
--- a/jobs/__init__.py
+++ b/jobs/__init__.py
@@ -1,3 +1,4 @@
 from .BaseJob import BaseJob
 from .ExtractJob import ExtractJob
 from .TrainJob import TrainJob
+from .MergeJob import MergeJob
diff --git a/jobs/process/BaseMergeProcess.py b/jobs/process/BaseMergeProcess.py
new file mode 100644
index 00000000..d5396dc3
--- /dev/null
+++ b/jobs/process/BaseMergeProcess.py
@@ -0,0 +1,56 @@
+import os
+from collections import OrderedDict
+
+from safetensors.torch import save_file
+
+from jobs.process.BaseProcess import BaseProcess
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.train_tools import get_torch_dtype
+
+
+class BaseMergeProcess(BaseProcess):
+    """Base for merge processes: reads output path and dtype, saves tensors."""
+    process_id: int
+    config: OrderedDict
+
+    def __init__(
+            self,
+            process_id: int,
+            job,
+            config: OrderedDict
+    ):
+        super().__init__(process_id, job, config)
+        self.output_path = self.get_conf('output_path', required=True)
+        self.dtype = self.get_conf('dtype', self.job.dtype)
+        self.torch_dtype = get_torch_dtype(self.dtype)
+
+    def run(self):
+        # implement in child class
+        # be sure to call super().run() first
+        pass
+
+    def save(self, state_dict):
+        """Write ``state_dict`` to ``self.output_path`` with ``save_file``.
+
+        Every tensor is detached, cloned, moved to the CPU and cast to
+        ``self.torch_dtype`` before saving. The ``state_dict`` passed in
+        by the caller is not modified.
+        """
+        # prepare meta
+        save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+
+        # a bare filename has no parent directory, and os.makedirs('') raises
+        output_dir = os.path.dirname(self.output_path)
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+        # build a new dict so the caller's entries are not replaced in place
+        converted = {
+            key: value.detach().clone().to("cpu").to(self.torch_dtype)
+            for key, value in state_dict.items()
+        }
+
+        # having issues with meta
+        save_file(converted, self.output_path, save_meta)
+
+        print(f"Saved to {self.output_path}")
diff --git a/jobs/process/MergeLoconProcess.py b/jobs/process/MergeLoconProcess.py
new file mode 100644
index 00000000..00c70cd2
--- /dev/null
+++ b/jobs/process/MergeLoconProcess.py
@@ -0,0 +1,20 @@
+from collections import OrderedDict
+from toolkit.lycoris_utils import extract_diff
+from .BaseExtractProcess import BaseExtractProcess
+
+
+# NOTE(review): subclasses BaseExtractProcess, not BaseMergeProcess - confirm intended
+class MergeLoconProcess(BaseExtractProcess):
+    """Placeholder process; ``run`` always raises NotImplementedError."""
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+
+    def run(self):
+        super().run()
+        raise NotImplementedError("This is not implemented yet")
+
+    def get_output_path(self, prefix=None, suffix=None):
+        if suffix is None:
+            suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
+        return super().get_output_path(prefix, suffix)
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index ef526b7d..90efff03 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -196,6 +196,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
+        self.first_step = 0
 
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.writer = self.job.writer
@@ -462,6 +463,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.max_steps = num_steps
         self.epochs = num_epochs
         start_step = self.step_num
+        self.first_step = start_step
 
         self.print(f"Training VAE")
         self.print(f" - Training folder: {self.training_folder}")
diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py
index 6160dff7..413aebce 100644
--- a/jobs/process/__init__.py
+++ b/jobs/process/__init__.py
@@ -4,3 +4,4 @@ from .ExtractLoraProcess import ExtractLoraProcess
 from .BaseProcess import BaseProcess
 from .BaseTrainProcess import BaseTrainProcess
 from .TrainVAEProcess import TrainVAEProcess
+from .BaseMergeProcess import BaseMergeProcess