Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-20 22:39:03 +00:00
Added extensions and an example extension that merges models
extensions/example/ExampleMergeModels.py (new file, 129 lines added)
@@ -0,0 +1,129 @@
import torch
import gc
from collections import OrderedDict
from typing import TYPE_CHECKING
from jobs.process import BaseExtensionProcess
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
from tqdm import tqdm

# Type check imports. Prevents circular imports
if TYPE_CHECKING:
    from jobs import ExtensionJob


# extend standard config classes to add weight
class ModelInputConfig(ModelConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.weight = kwargs.get('weight', 1.0)
        # overwrite default dtype unless user specifies otherwise
        # float32 will give us better precision on the merging functions
        self.dtype: str = kwargs.get('dtype', 'float32')


def flush():
    torch.cuda.empty_cache()
    gc.collect()


# this is our main process class
class ExampleMergeModels(BaseExtensionProcess):
    def __init__(
            self,
            process_id: int,
            job: 'ExtensionJob',
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # This is the setup step. Do not do process-intensive work here, just variable setup and
        # requirement checks. It is called before the run() function.
        # No loading models or anything like that; it is just for setting up the process.
        # All of your process-intensive work should be done in the run() function.
        # config will have everything from the process item in the config file.

        # Convenience methods exist on BaseProcess to get config values.
        # If required is set to True and the value is not found, an error is raised.
        # You can also pass a default value to get_conf() for when it is not in the config file,
        # as well as a type to cast the value to.
        self.save_path = self.get_conf('save_path', required=True)
        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
        self.device = self.get_conf('device', default='cpu', as_type=torch.device)

        # build the list of models to merge
        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
        # Build a list of ModelInputConfig objects. It is a good idea to make a class for each
        # config; that way you can add methods to it and it is easier to read and code. There are
        # a lot of built-in config classes located in toolkit.config_modules as well.
        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
        # Setup is complete. Don't load anything else here, just set up variables.

    # this is the entire run process; be sure to call super().run() first
    def run(self):
        # always call first
        super().run()
        print(f"Running process: {self.__class__.__name__}")

        # adjust our weights first to normalize them so the total is 1.0
        total_weight = sum([model.weight for model in self.models_to_merge])
        weight_adjust = 1.0 / total_weight
        for model in self.models_to_merge:
            model.weight *= weight_adjust

        output_model: StableDiffusion = None
        # do the merge; it is a good idea to use tqdm to show progress
        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
            # set up the model with our helper class
            sd_model = StableDiffusion(
                device=self.device,
                model_config=model_config,
                dtype="float32"
            )
            # load the model
            sd_model.load_model()

            # adjust the weights of the text encoder
            if isinstance(sd_model.text_encoder, list):
                # sdxl model
                for text_encoder in sd_model.text_encoder:
                    for key, value in text_encoder.state_dict().items():
                        value *= model_config.weight
            else:
                # normal model
                for key, value in sd_model.text_encoder.state_dict().items():
                    value *= model_config.weight
            # adjust the weights of the unet
            for key, value in sd_model.unet.state_dict().items():
                value *= model_config.weight

            if output_model is None:
                # use this one as the base
                output_model = sd_model
            else:
                # merge the models
                # text encoder
                if isinstance(output_model.text_encoder, list):
                    # sdxl model
                    for i, text_encoder in enumerate(output_model.text_encoder):
                        for key, value in text_encoder.state_dict().items():
                            value += sd_model.text_encoder[i].state_dict()[key]
                else:
                    # normal model
                    for key, value in output_model.text_encoder.state_dict().items():
                        value += sd_model.text_encoder.state_dict()[key]
                # unet
                for key, value in output_model.unet.state_dict().items():
                    value += sd_model.unet.state_dict()[key]

                # remove the model to free memory
                del sd_model
                flush()

        # the merge loop is done; save the model
        print(f"Saving merged model to {self.save_path}")
        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
        print(f"Saved merged model to {self.save_path}")
        # do cleanup here
        del output_model
        flush()
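The merge in run() is a weighted average: each model's tensors are scaled in place by its normalized weight and then accumulated into the first (base) model. A minimal standalone sketch of the same arithmetic on plain state dicts (merge_state_dicts is a hypothetical helper for illustration, not part of this commit or the toolkit):

import torch

def merge_state_dicts(state_dicts, weights):
    # normalize the weights so they sum to 1.0, just like run() does
    total = sum(weights)
    weights = [w / total for w in weights]

    merged = None
    for sd, w in zip(state_dicts, weights):
        if merged is None:
            # scale the first model and use it as the base
            merged = {k: v.float() * w for k, v in sd.items()}
        else:
            # accumulate every other model's scaled tensors into the base
            for k, v in sd.items():
                merged[k] += v.float() * w
    return merged

# two tiny fake "models", just to show the arithmetic
a = {'layer.weight': torch.ones(2, 2)}
b = {'layer.weight': torch.zeros(2, 2)}
print(merge_state_dicts([a, b], [1.0, 1.0]))  # every entry becomes 0.5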
extensions/example/__init__.py (new file, 25 lines added)
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ExampleMergeExtension(Extension):
    # uid must be unique; it is how the extension is identified
    uid = "example_merge_extension"

    # name is the name of the extension for printing
    name = "Example Merge Extension"

    # This is where your process class is loaded.
    # Keep your imports in here so they don't slow down the rest of the program.
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed, and return it
        from .ExampleMergeModels import ExampleMergeModels
        return ExampleMergeModels


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ExampleMergeExtension
]
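For orientation: the process type in the config is matched against an extension's uid, and the matching extension's get_process() returns the process class to run. How the toolkit performs that lookup is not shown in this commit; the snippet below is only a simplified, hypothetical illustration of the matching, assuming the ai-toolkit repo root is on the Python path:

# hypothetical illustration of type -> uid matching; the real lookup happens inside the toolkit
from extensions.example import AI_TOOLKIT_EXTENSIONS

def resolve_process_class(process_type: str):
    for ext in AI_TOOLKIT_EXTENSIONS:
        if ext.uid == process_type:
            # get_process() imports and returns the process class lazily
            return ext.get_process()
    raise ValueError(f"No extension registered for type '{process_type}'")

ProcessClass = resolve_process_class("example_merge_extension")
print(ProcessClass.__name__)  # ExampleMergeModels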
extensions/example/config/config.example.yaml (new file, 48 lines added)
@@ -0,0 +1,48 @@
---
# Always include at least one example config file to show how to use your extension.
# Use plenty of comments so users know how to use it and what everything does.

# all extensions will use this job name
job: extension
config:
  name: 'my_awesome_merge'
  process:
    # Put your example processes here. This will be passed
    # to your extension process in the config argument.
    # the type MUST match your extension uid
    - type: "example_merge_extension"
      # save path for the merged model
      save_path: "output/merge/[name].safetensors"
      # save dtype
      dtype: fp16
      # device to run it on
      device: cuda:0
      # input models can only be SD1.x and SD2.x models for this example (currently)
      models_to_merge:
        # Weights are relative; total weights will be normalized.
        # For example, if you have 2 models with weight 1.0, they will
        # both be weighted 0.5. If you have 1 model with weight 1.0 and
        # another with weight 2.0, the first will be weighted 1/3 and the
        # second will be weighted 2/3.
        - name_or_path: "input/model1.safetensors"
          weight: 1.0
        - name_or_path: "input/model2.safetensors"
          weight: 1.0
        - name_or_path: "input/model3.safetensors"
          weight: 0.3
        - name_or_path: "input/model4.safetensors"
          weight: 1.0

# You can put any information you want here, and it will be saved in the model metadata.
# The below is an example. I recommend including trigger words at a minimum.
# The software will include this plus some other information.
meta:
  name: "[name]" # [name] gets replaced with the name above
  description: A short description of your model
  version: '0.1'
  creator:
    name: Your Name
    email: your@email.com
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
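With the four example weights above (1.0, 1.0, 0.3, 1.0), normalization divides by their sum of 3.3, so the models contribute roughly 0.303, 0.303, 0.091, and 0.303 of the merge respectively. A quick check of that normalization, mirroring what the process does in run():

weights = [1.0, 1.0, 0.3, 1.0]
total = sum(weights)                       # 3.3
normalized = [w / total for w in weights]
print(normalized)       # [0.3030..., 0.3030..., 0.0909..., 0.3030...]
print(sum(normalized))  # 1.0 (up to floating point)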