mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-31 10:59:47 +00:00
Added extensions and an example extension that merges models
This commit is contained in:
129
extensions/example/ExampleMergeModels.py
Normal file
129
extensions/example/ExampleMergeModels.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
import gc
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING
|
||||
from jobs.process import BaseExtensionProcess
|
||||
from toolkit.config_modules import ModelConfig
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from tqdm import tqdm
|
||||
|
||||
# Type check imports. Prevents circular imports
|
||||
if TYPE_CHECKING:
|
||||
from jobs import ExtensionJob
|
||||
|
||||
|
||||
# extend standard config classes to add weight
class ModelInputConfig(ModelConfig):
    """A ModelConfig that also carries a per-model merge weight.

    Adds ``weight`` (relative contribution of this model to the merge,
    default 1.0) and defaults ``dtype`` to float32 for better precision
    during the merge math, unless the user overrides it.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # relative contribution of this model to the merged output
        merge_weight = kwargs.get('weight', 1.0)
        self.weight = merge_weight
        # overwrite the inherited default dtype unless the user set one;
        # float32 gives better precision for the merging arithmetic
        self.dtype: str = kwargs.get('dtype', 'float32')
|
||||
|
||||
|
||||
def flush():
    """Free host-side garbage, then release cached GPU memory.

    ``gc.collect()`` runs first so that any tensors kept alive only by
    collectable Python reference cycles are actually freed before
    ``torch.cuda.empty_cache()`` returns the allocator's cached blocks to
    the driver. (The original order emptied the cache first, so memory
    freed by the collector stayed cached until the *next* flush.)
    ``empty_cache()`` is a no-op when CUDA has not been initialized, so
    this is safe on CPU-only runs.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# this is our main class process
class ExampleMergeModels(BaseExtensionProcess):
    """Example extension process that merges several StableDiffusion models.

    Each input model's text-encoder and unet weights are scaled by a
    normalized per-model weight and accumulated into a single output
    model, which is then saved to ``save_path`` in ``save_dtype``.
    """

    def __init__(
            self,
            process_id: int,
            job: 'ExtensionJob',
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # this is the setup process, do not do process intensive stuff here, just variable setup and
        # checking requirements. This is called before the run() function
        # no loading models or anything like that, it is just for setting up the process
        # all of your process intensive stuff should be done in the run() function
        # config will have everything from the process item in the config file

        # convenience methods exist on BaseProcess to get config values
        # if required is set to true and the value is not found it will throw an error
        # you can pass a default value to get_conf() as well if it was not in the config file
        # as well as a type to cast the value to
        self.save_path = self.get_conf('save_path', required=True)
        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
        self.device = self.get_conf('device', default='cpu', as_type=torch.device)

        # build models to merge list (raw dicts straight from the config file)
        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
        # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
        # this way you can add methods to it and it is easier to read and code. There are a lot of
        # inbuilt config classes located in toolkit.config_modules as well
        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
        # setup is complete. Don't load anything else here, just setup variables and stuff

    # this is the entire run process, be sure to call super().run() first
    def run(self):
        # always call first
        super().run()
        print(f"Running process: {self.__class__.__name__}")

        # normalize the per-model weights so they sum to 1.0; this keeps the
        # merged weights at the same overall magnitude as a single model
        total_weight = sum([model.weight for model in self.models_to_merge])
        weight_adjust = 1.0 / total_weight
        for model in self.models_to_merge:
            model.weight *= weight_adjust

        # the first model processed becomes the accumulator; later models are
        # added into it in place
        output_model: StableDiffusion = None
        # let's do the merge, it is a good idea to use tqdm to show progress
        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
            # setup model class with our helper class
            # dtype is forced to float32 here so the merge math is done at
            # full precision regardless of save_dtype
            sd_model = StableDiffusion(
                device=self.device,
                model_config=model_config,
                dtype="float32"
            )
            # load the model (process-intensive; this is why it lives in run())
            sd_model.load_model()

            # scale the text encoder weights by this model's normalized weight
            # NOTE(review): the in-place *= below mutates the tensors returned
            # by state_dict(); this relies on those tensors sharing storage
            # with the model's parameters — confirm against the torch version
            # in use before changing this pattern
            if isinstance(sd_model.text_encoder, list):
                # sdxl model: two text encoders
                for text_encoder in sd_model.text_encoder:
                    for key, value in text_encoder.state_dict().items():
                        value *= model_config.weight
            else:
                # normal model: single text encoder
                for key, value in sd_model.text_encoder.state_dict().items():
                    value *= model_config.weight
            # scale the unet weights the same way
            for key, value in sd_model.unet.state_dict().items():
                value *= model_config.weight

            if output_model is None:
                # use this one as the base (already weight-scaled above)
                output_model = sd_model
            else:
                # merge the models: add this model's scaled weights into the
                # accumulator in place (same storage-sharing assumption as above)
                # text encoder
                if isinstance(output_model.text_encoder, list):
                    # sdxl model: pair up the encoders by index
                    for i, text_encoder in enumerate(output_model.text_encoder):
                        for key, value in text_encoder.state_dict().items():
                            value += sd_model.text_encoder[i].state_dict()[key]
                else:
                    # normal model
                    for key, value in output_model.text_encoder.state_dict().items():
                        value += sd_model.text_encoder.state_dict()[key]
                # unet
                for key, value in output_model.unet.state_dict().items():
                    value += sd_model.unet.state_dict()[key]

                # remove the model to free memory before loading the next one
                del sd_model
                flush()

        # merge loop is done, let's save the model
        print(f"Saving merged model to {self.save_path}")
        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
        print(f"Saved merged model to {self.save_path}")
        # do cleanup here
        del output_model
        flush()
|
||||
Reference in New Issue
Block a user