import gc
from collections import OrderedDict
from typing import TYPE_CHECKING

import torch
from tqdm import tqdm

from jobs.process import BaseExtensionProcess
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype

# type-checking-only imports; prevents circular imports at runtime
if TYPE_CHECKING:
    from jobs import ExtensionJob


# extend the standard config class to add a per-model merge weight
class ModelInputConfig(ModelConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.weight = kwargs.get('weight', 1.0)
        # overwrite the default dtype unless the user specifies otherwise;
        # float32 gives better precision for the merge math
        self.dtype: str = kwargs.get('dtype', 'float32')
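
# A minimal sketch of building one of these configs by hand. The
# name_or_path field is an assumption about what ModelConfig accepts;
# weight and dtype are the fields handled above. The path is hypothetical.
#
#   example = ModelInputConfig(
#       name_or_path='/path/to/model.safetensors',  # assumed ModelConfig field
#       weight=0.5,
#   )
#   example.dtype  # 'float32' unless overridden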


def flush():
    # free cached GPU memory and collect garbage so deleted models
    # actually release their VRAM
    torch.cuda.empty_cache()
    gc.collect()


# this is our main process class
class ExampleMergeModels(BaseExtensionProcess):
    def __init__(
            self,
            process_id: int,
            job: 'ExtensionJob',
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # this is the setup step. Do not do anything process intensive here,
        # just variable setup and requirement checks. It is called before the
        # run() function, so no loading models or anything like that; all of
        # your process-intensive work belongs in run().
        # config will have everything from the process item in the config file

        # convenience methods exist on BaseProcess to get config values.
        # if required is True and the value is not found, an error is raised.
        # you can also pass a default value to get_conf() for when the key is
        # not in the config file, as well as a type to cast the value to
        self.save_path = self.get_conf('save_path', required=True)
        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
        self.device = self.get_conf('device', default='cpu', as_type=torch.device)

        # build the list of models to merge
        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
        # build a list of ModelInputConfig objects. It is a good idea to make a
        # class for each config; that way you can add methods to it and the
        # code is easier to read. There are a lot of built-in config classes
        # in toolkit.config_modules as well
        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
        # setup is complete. Don't load anything else here; just set up variables
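
        # A hedged sketch of what this process item might look like in the job
        # config file. Only save_path, save_dtype, device, models_to_merge, and
        # weight are read above; the type name and name_or_path are assumptions:
        #
        #   - type: example_merge_models        # hypothetical type name
        #     save_path: output/merged.safetensors
        #     save_dtype: float16
        #     device: cpu
        #     models_to_merge:
        #       - name_or_path: /path/to/model_a.safetensors
        #         weight: 1.0
        #       - name_or_path: /path/to/model_b.safetensors
        #         weight: 3.0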

    # this is the entire run process; be sure to call super().run() first
    def run(self):
        # always call first
        super().run()
        print(f"Running process: {self.__class__.__name__}")

        # let's adjust our weights first to normalize them so the total is 1.0
        total_weight = sum([model.weight for model in self.models_to_merge])
        weight_adjust = 1.0 / total_weight
        for model in self.models_to_merge:
            model.weight *= weight_adjust
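        # worked example (not from the source): weights [1.0, 3.0] total 4.0,
        # so weight_adjust = 0.25 and the normalized weights become
        # [0.25, 0.75], which sum to 1.0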

        output_model: StableDiffusion = None
        # let's do the merge; it is a good idea to use tqdm to show progress
        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
            # set up the model with our helper class
            sd_model = StableDiffusion(
                device=self.device,
                model_config=model_config,
                dtype="float32"
            )
            # load the model
            sd_model.load_model()
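            # after load_model(), sd_model exposes the text_encoder (a list of
            # two encoders for sdxl) and unet modules that are scaled below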

            # scale the text encoder weights by this model's merge weight.
            # state_dict() returns references to the live tensors, so the
            # in-place *= modifies the model weights directly
            if isinstance(sd_model.text_encoder, list):
                # sdxl model
                for text_encoder in sd_model.text_encoder:
                    for key, value in text_encoder.state_dict().items():
                        value *= model_config.weight
            else:
                # normal model
                for key, value in sd_model.text_encoder.state_dict().items():
                    value *= model_config.weight
            # scale the weights of the unet the same way
            for key, value in sd_model.unet.state_dict().items():
                value *= model_config.weight

            if output_model is None:
                # use the first model as the base to accumulate into
                output_model = sd_model
            else:
                # add this model's pre-scaled weights onto the base
                # text encoder first
                if isinstance(output_model.text_encoder, list):
                    # sdxl model
                    for i, text_encoder in enumerate(output_model.text_encoder):
                        for key, value in text_encoder.state_dict().items():
                            value += sd_model.text_encoder[i].state_dict()[key]
                else:
                    # normal model
                    for key, value in output_model.text_encoder.state_dict().items():
                        value += sd_model.text_encoder.state_dict()[key]
                # unet
                for key, value in output_model.unet.state_dict().items():
                    value += sd_model.unet.state_dict()[key]
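            # net effect of the loop, stated as math rather than taken from the
            # source: merged = w_1 * M_1 + w_2 * M_2 + ..., a weighted average,
            # since the w_i were normalized above to sum to 1.0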

            # drop our reference so the model can be freed (the first model
            # lives on as output_model)
            del sd_model
            flush()

        # merge loop is done, let's save the model
        print(f"Saving merged model to {self.save_path}")
        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
        print(f"Saved merged model to {self.save_path}")
        # do cleanup here
        del output_model
        flush()