Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Set up the base process for merging things. WIP
jobs/MergeJob.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from collections import OrderedDict
from jobs import BaseJob
from toolkit.train_tools import get_torch_dtype

process_dict = {
}


class MergeJob(BaseJob):

    def __init__(self, config: OrderedDict):
        super().__init__(config)
        self.dtype = self.get_conf('dtype', 'fp16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.is_v2 = self.get_conf('is_v2', False)
        self.device = self.get_conf('device', 'cpu')

        # loads the processes from the config
        self.load_processes(process_dict)

    def run(self):
        super().run()

        print("")
        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

        for process in self.process:
            process.run()
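Note that process_dict is empty in this commit, so load_processes has nothing to resolve yet. Going by how the other jobs in this repo map a process's config "type" to a class, it would presumably be filled in later with something like the sketch below; the 'merge_locon' key is an assumption, not part of this commit.

# Hypothetical sketch only, not part of this commit: register the merge
# processes this job should be able to load from its config.
from jobs.process.MergeLoconProcess import MergeLoconProcess

process_dict = {
    'merge_locon': MergeLoconProcess,  # assumed type name; the class itself is added later in this commit
}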
jobs/__init__.py
@@ -1,3 +1,4 @@
 from .BaseJob import BaseJob
 from .ExtractJob import ExtractJob
 from .TrainJob import TrainJob
+from .MergeJob import MergeJob
jobs/process/BaseMergeProcess.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import os
from collections import OrderedDict

from safetensors.torch import save_file

from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors
from toolkit.train_tools import get_torch_dtype


class BaseMergeProcess(BaseProcess):
    process_id: int
    config: OrderedDict

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.output_path = self.get_conf('output_path', required=True)
        self.dtype = self.get_conf('dtype', self.job.dtype)
        self.torch_dtype = get_torch_dtype(self.dtype)

    def run(self):
        # implement in child class
        # be sure to call super().run() first
        pass

    def save(self, state_dict):
        # prepare meta
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)

        # save
        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)

        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(self.torch_dtype)
            state_dict[key] = v

        # having issues with meta
        save_file(state_dict, self.output_path, save_meta)

        print(f"Saved to {self.output_path}")
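BaseMergeProcess leaves run() to child classes (which should call super().run() first) and gives them save(), which moves every tensor to the CPU, casts it to the configured dtype, and writes a safetensors file with the job metadata. A minimal hypothetical subclass, just to show the contract (the class name and empty state dict are placeholders, not code from the repo):

from collections import OrderedDict

from jobs.process.BaseMergeProcess import BaseMergeProcess


class NoopMergeProcess(BaseMergeProcess):
    # placeholder subclass: performs no real merge, only exercises the base class
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)

    def run(self):
        super().run()
        merged_state_dict = {}  # a real process would build the merged weights here
        self.save(merged_state_dict)  # casts to self.torch_dtype on CPU and writes safetensors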
jobs/process/MergeLoconProcess.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess


class MergeLoconProcess(BaseExtractProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)

    def run(self):
        super().run()
        new_state_dict = {}
        raise NotImplementedError("This is not implemented yet")


    def get_output_path(self, prefix=None, suffix=None):
        if suffix is None:
            suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
        return super().get_output_path(prefix, suffix)
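run() is still a stub here, so the actual merge logic is not in this commit. As a generic illustration only (not this repo's implementation), folding a plain linear LoRA module into its base weight usually reduces to W + multiplier * (alpha / rank) * (up @ down); LoCon conv and Tucker-decomposed layers need extra reshaping that is omitted:

import torch


def merge_linear_lora(weight: torch.Tensor,
                      up: torch.Tensor,     # (out_features, rank)
                      down: torch.Tensor,   # (rank, in_features)
                      alpha: float,
                      multiplier: float = 1.0) -> torch.Tensor:
    # generic linear LoRA merge, for illustration only
    rank = down.shape[0]
    scale = multiplier * alpha / rank
    return weight + scale * (up @ down)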
jobs/process/TrainVAEProcess.py
@@ -196,6 +196,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
+        self.first_step = 0

         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.writer = self.job.writer
@@ -462,6 +463,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.max_steps = num_steps
         self.epochs = num_epochs
         start_step = self.step_num
+        self.first_step = start_step

         self.print(f"Training VAE")
         self.print(f" - Training folder: {self.training_folder}")
jobs/process/__init__.py
@@ -4,3 +4,4 @@ from .ExtractLoraProcess import ExtractLoraProcess
 from .BaseProcess import BaseProcess
 from .BaseTrainProcess import BaseTrainProcess
 from .TrainVAEProcess import TrainVAEProcess
+from .BaseMergeProcess import BaseMergeProcess