Set up the base process for merging things. WIP

This commit is contained in:
Jaret Burkett
2023-07-20 07:39:31 -06:00
parent 557732e7ff
commit c29b9d075f
6 changed files with 99 additions and 0 deletions

29
jobs/MergeJob.py Normal file
View File

@@ -0,0 +1,29 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from collections import OrderedDict
from jobs import BaseJob
from toolkit.train_tools import get_torch_dtype
# Registry handed to `load_processes` in MergeJob.__init__. It is empty for
# now, so this module registers no process types yet.
# NOTE(review): presumably maps a process type name to its class — confirm
# against BaseJob.load_processes.
process_dict = {
}
class MergeJob(BaseJob):
    """Job that loads its processes from the config and runs them in order.

    Job-level settings read from the config (with their fallbacks):
    ``dtype`` ('fp16'), ``is_v2`` (False) and ``device`` ('cpu').
    """

    def __init__(self, config: OrderedDict):
        """Read the job-level settings and load the configured processes.

        Args:
            config: the job configuration passed through to ``BaseJob``.
        """
        super().__init__(config)

        # Job-wide settings; the torch dtype is derived from the dtype name.
        self.dtype = self.get_conf('dtype', 'fp16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.is_v2 = self.get_conf('is_v2', False)
        self.device = self.get_conf('device', 'cpu')

        # Populate the process list from the config using the module registry.
        self.load_processes(process_dict)

    def run(self):
        """Run the base job logic, then every loaded process sequentially."""
        super().run()

        print("")
        process_count = len(self.process)
        print(f"Running {process_count} process{'' if process_count == 1 else 'es'}")

        for merge_process in self.process:
            merge_process.run()

View File

@@ -1,3 +1,4 @@
from .BaseJob import BaseJob from .BaseJob import BaseJob
from .ExtractJob import ExtractJob from .ExtractJob import ExtractJob
from .TrainJob import TrainJob from .TrainJob import TrainJob
from .MergeJob import MergeJob

View File

@@ -0,0 +1,46 @@
import os
from collections import OrderedDict
from safetensors.torch import save_file
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors
from toolkit.train_tools import get_torch_dtype
class BaseMergeProcess(BaseProcess):
    """Base class for merge processes.

    Reads the output settings from the process config and provides ``save``
    for writing a state dict to ``output_path`` via safetensors ``save_file``.
    Subclasses are expected to override ``run``.
    """
    process_id: int
    config: OrderedDict

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        """Read the output path and dtype for this process.

        Args:
            process_id: identifier forwarded to ``BaseProcess``.
            job: the owning job; its ``dtype`` is the fallback for this
                process's ``dtype`` setting.
            config: the process configuration; ``output_path`` is required.
        """
        super().__init__(process_id, job, config)
        self.output_path = self.get_conf('output_path', required=True)
        # A per-process dtype overrides the job-level one when present.
        self.dtype = self.get_conf('dtype', self.job.dtype)
        self.torch_dtype = get_torch_dtype(self.dtype)

    def run(self):
        # implement in child class
        # be sure to call super().run() first
        pass

    def save(self, state_dict):
        """Convert every tensor to CPU / ``self.torch_dtype`` and save the dict.

        The entries of ``state_dict`` are replaced in place with detached,
        cloned copies on the CPU in the target dtype before being written.

        Args:
            state_dict: mapping of keys to tensors to be written to
                ``self.output_path``.
        """
        # prepare meta
        # NOTE(review): self.meta is not assigned in this class — presumably
        # provided by BaseProcess; confirm.
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)

        # save
        # os.path.dirname returns '' for a bare filename, and os.makedirs('')
        # raises FileNotFoundError, so only create the directory when the
        # output path actually has a directory component.
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Snapshot the keys so the dict can be updated while looping.
        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(self.torch_dtype)
            state_dict[key] = v

        # having issues with meta
        save_file(state_dict, self.output_path, save_meta)

        print(f"Saved to {self.output_path}")

View File

@@ -0,0 +1,20 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
class MergeLoconProcess(BaseExtractProcess):
    """Work-in-progress merge process; ``run`` currently always raises."""

    def __init__(self, process_id: int, job, config: OrderedDict):
        """Forward all arguments unchanged to the base process."""
        super().__init__(process_id, job, config)

    def run(self):
        """Run the base process logic, then fail: the merge is not written yet.

        Raises:
            NotImplementedError: always, after ``super().run()`` returns.
        """
        super().run()
        # Placeholder for the merged weights; nothing fills it yet.
        merged_state_dict = {}
        raise NotImplementedError("This is not implemented yet")

    def get_output_path(self, prefix=None, suffix=None):
        """Return the base class's output path, supplying a default suffix.

        When ``suffix`` is None it is built from ``self.mode``,
        ``self.linear_param`` and ``self.conv_param``.
        NOTE(review): those attributes are not set in this class — presumably
        expected from the base class or config; confirm.
        """
        if suffix is not None:
            return super().get_output_path(prefix, suffix)

        generated_suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
        return super().get_output_path(prefix, generated_suffix)

View File

@@ -196,6 +196,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
self.first_step = 0
self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
self.writer = self.job.writer self.writer = self.job.writer
@@ -462,6 +463,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.max_steps = num_steps self.max_steps = num_steps
self.epochs = num_epochs self.epochs = num_epochs
start_step = self.step_num start_step = self.step_num
self.first_step = start_step
self.print(f"Training VAE") self.print(f"Training VAE")
self.print(f" - Training folder: {self.training_folder}") self.print(f" - Training folder: {self.training_folder}")

View File

@@ -4,3 +4,4 @@ from .ExtractLoraProcess import ExtractLoraProcess
from .BaseProcess import BaseProcess from .BaseProcess import BaseProcess
from .BaseTrainProcess import BaseTrainProcess from .BaseTrainProcess import BaseTrainProcess
from .TrainVAEProcess import TrainVAEProcess from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess