Added extensions and an example extension that merges models

This commit is contained in:
Jaret Burkett
2023-08-04 09:37:24 -06:00
parent b865ac8b24
commit 7e4e660663
14 changed files with 366 additions and 24 deletions

View File

@@ -0,0 +1,129 @@
import torch
import gc
from collections import OrderedDict
from typing import TYPE_CHECKING
from jobs.process import BaseExtensionProcess
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
from tqdm import tqdm
# Type check imports. Prevents circular imports
if TYPE_CHECKING:
from jobs import ExtensionJob
# extend standard config classes to add weight
class ModelInputConfig(ModelConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.weight = kwargs.get('weight', 1.0)
# overwrite default dtype unless user specifies otherwise
# float 32 will give up better precision on the merging functions
self.dtype: str = kwargs.get('dtype', 'float32')
def flush():
torch.cuda.empty_cache()
gc.collect()
# this is our main class process
class ExampleMergeModels(BaseExtensionProcess):
def __init__(
self,
process_id: int,
job: 'ExtensionJob',
config: OrderedDict
):
super().__init__(process_id, job, config)
# this is the setup process, do not do process intensive stuff here, just variable setup and
# checking requirements. This is called before the run() function
# no loading models or anything like that, it is just for setting up the process
# all of your process intensive stuff should be done in the run() function
# config will have everything from the process item in the config file
# convince methods exist on BaseProcess to get config values
# if required is set to true and the value is not found it will throw an error
# you can pass a default value to get_conf() as well if it was not in the config file
# as well as a type to cast the value to
self.save_path = self.get_conf('save_path', required=True)
self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
self.device = self.get_conf('device', default='cpu', as_type=torch.device)
# build models to merge list
models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
# build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
# this way you can add methods to it and it is easier to read and code. There are a lot of
# inbuilt config classes located in toolkit.config_modules as well
self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
# setup is complete. Don't load anything else here, just setup variables and stuff
# this is the entire run process be sure to call super().run() first
def run(self):
# always call first
super().run()
print(f"Running process: {self.__class__.__name__}")
# let's adjust our weights first to normalize them so the total is 1.0
total_weight = sum([model.weight for model in self.models_to_merge])
weight_adjust = 1.0 / total_weight
for model in self.models_to_merge:
model.weight *= weight_adjust
output_model: StableDiffusion = None
# let's do the merge, it is a good idea to use tqdm to show progress
for model_config in tqdm(self.models_to_merge, desc="Merging models"):
# setup model class with our helper class
sd_model = StableDiffusion(
device=self.device,
model_config=model_config,
dtype="float32"
)
# load the model
sd_model.load_model()
# adjust the weight of the text encoder
if isinstance(sd_model.text_encoder, list):
# sdxl model
for text_encoder in sd_model.text_encoder:
for key, value in text_encoder.state_dict().items():
value *= model_config.weight
else:
# normal model
for key, value in sd_model.text_encoder.state_dict().items():
value *= model_config.weight
# adjust the weights of the unet
for key, value in sd_model.unet.state_dict().items():
value *= model_config.weight
if output_model is None:
# use this one as the base
output_model = sd_model
else:
# merge the models
# text encoder
if isinstance(output_model.text_encoder, list):
# sdxl model
for i, text_encoder in enumerate(output_model.text_encoder):
for key, value in text_encoder.state_dict().items():
value += sd_model.text_encoder[i].state_dict()[key]
else:
# normal model
for key, value in output_model.text_encoder.state_dict().items():
value += sd_model.text_encoder.state_dict()[key]
# unet
for key, value in output_model.unet.state_dict().items():
value += sd_model.unet.state_dict()[key]
# remove the model to free memory
del sd_model
flush()
# merge loop is done, let's save the model
print(f"Saving merged model to {self.save_path}")
output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
print(f"Saved merged model to {self.save_path}")
# do cleanup here
del output_model
flush()

View File

@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# We make a subclass of Extension
class ExampleMergeExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "example_merge_extension"
# name is the name of the extension for printing
name = "Example Merge Extension"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .ExampleMergeModels import ExampleMergeModels
return ExampleMergeModels
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
ExampleMergeExtension
]

View File

@@ -0,0 +1,48 @@
---
# Always include at least one example config file to show how to use your extension.
# use plenty of comments so users know how to use it and what everything does
# all extensions will use this job name
job: extension
config:
name: 'my_awesome_merge'
process:
# Put your example processes here. This will be passed
# to your extension process in the config argument.
# the type MUST match your extension uid
- type: "example_merge_extension"
# save path for the merged model
save_path: "output/merge/[name].safetensors"
# save type
dtype: fp16
# device to run it on
device: cuda:0
# input models can only be SD1.x and SD2.x models for this example (currently)
models_to_merge:
# weights are relative, total weights will be normalized
# for example. If you have 2 models with weight 1.0, they will
# both be weighted 0.5. If you have 1 model with weight 1.0 and
# another with weight 2.0, the first will be weighted 1/3 and the
# second will be weighted 2/3
- name_or_path: "input/model1.safetensors"
weight: 1.0
- name_or_path: "input/model2.safetensors"
weight: 1.0
- name_or_path: "input/model3.safetensors"
weight: 0.3
- name_or_path: "input/model4.safetensors"
weight: 1.0
# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
name: "[name]" # [name] gets replaced with the name above
description: A short description of your model
version: '0.1'
creator:
name: Your Name
email: your@email.com
website: https://yourwebsite.com
any: All meta data above is arbitrary, it can be whatever you want.