mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added extensions and an example extension that merges models
2 .gitignore vendored

@@ -171,3 +171,5 @@ cython_debug/
 !/config/_PUT_YOUR_CONFIGS_HERE).txt
 /output/*
 !/output/.gitkeep
+/extensions/*
+!/extensions/example
23 README.md

@@ -126,6 +126,23 @@ I will post a better tutorial soon.

---

## Extensions!!

You can now make and share custom extensions that run within this framework and have all the inbuilt tools available to them. I will probably use this as the primary development method going forward so I don't keep adding more and more features to this base repo. I will likely migrate a lot of the existing functionality as well to make everything modular. There is an example extension in the `extensions` folder that shows how to make a model merger extension. All of the code is heavily documented, which is hopefully enough to get you started. To make an extension, just copy that example and replace all the things you need to.
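
At a glance, an extension is just a package in the `extensions` folder whose `__init__.py` exposes an `AI_TOOLKIT_EXTENSIONS` list. A condensed sketch of that shape (`MyExtension`, `my_extension`, and `MyProcess` are placeholder names; the full working version ships in this commit):

```python
# extensions/my_extension/__init__.py (placeholder names, condensed from the bundled example)
from toolkit.extension import Extension


class MyExtension(Extension):
    uid = "my_extension"   # must be unique; config files select it via their "type" field
    name = "My Extension"  # display name used when printing

    @classmethod
    def get_process(cls):
        # import lazily so unused extensions don't slow startup
        from .MyProcess import MyProcess
        return MyProcess


AI_TOOLKIT_EXTENSIONS = [MyExtension]
```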

### Model Merger - Example Extension

It is located in the `extensions` folder. It is a fully functional model merger that can merge as many models together as you want. It is a good example of how to make an extension, but it is also a pretty useful feature, since most mergers can only do one model at a time and this one will take as many as you want to feed it. There is an example config file in there; just copy it to your `config` folder, rename it to `whatever_you_want.yml`, and use it like any other config file.
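
Assuming the repo's usual entry point (not part of this diff, so double-check against the base README), running it then looks like `python run.py config/whatever_you_want.yml`.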

## WIP Tools


@@ -153,6 +170,12 @@ Just went in and out. It is much worse on smaller faces than shown here.

## Change Log

#### 2021-10-20
- Windows support bug fixes
- Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
  Check out the example in the `extensions` folder. Read more about that above.
- Model Merging, provided via the example extension.

#### 2021-08-03
Another big refactor to make SD more modular.

129 extensions/example/ExampleMergeModels.py (new file)

@@ -0,0 +1,129 @@
import torch
import gc
from collections import OrderedDict
from typing import TYPE_CHECKING
from jobs.process import BaseExtensionProcess
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
from tqdm import tqdm

# type-check-only imports; prevents circular imports
if TYPE_CHECKING:
    from jobs import ExtensionJob


# extend the standard config class to add a weight field
class ModelInputConfig(ModelConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.weight = kwargs.get('weight', 1.0)
        # overwrite the default dtype unless the user specifies otherwise;
        # float32 gives better precision in the merging functions
        self.dtype: str = kwargs.get('dtype', 'float32')


def flush():
    torch.cuda.empty_cache()
    gc.collect()


# this is our main process class
class ExampleMergeModels(BaseExtensionProcess):
    def __init__(
            self,
            process_id: int,
            job: 'ExtensionJob',
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # This is the setup step. Do not do process-intensive work here, just variable setup
        # and requirement checks; it is called before the run() function.
        # No loading models or anything like that; all of your process-intensive work
        # should be done in the run() function.
        # config will have everything from the process item in the config file.

        # Convenience methods exist on BaseProcess to get config values.
        # If required=True and the value is not found, it will throw an error.
        # You can pass a default value to get_conf() for when it is not in the config file,
        # as well as a type to cast the value to.
        self.save_path = self.get_conf('save_path', required=True)
        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
        self.device = self.get_conf('device', default='cpu', as_type=torch.device)

        # build the list of models to merge
        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
        # Build a list of ModelInputConfig objects. I find it is a good idea to make a class for
        # each config; this way you can add methods to it and it is easier to read and code.
        # There are a lot of inbuilt config classes located in toolkit.config_modules as well.
        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
        # setup is complete. Don't load anything else here, just set up variables and such

    # this is the entire run process; be sure to call super().run() first
    def run(self):
        # always call first
        super().run()
        print(f"Running process: {self.__class__.__name__}")

        # let's adjust our weights first to normalize them so the total is 1.0
        total_weight = sum([model.weight for model in self.models_to_merge])
        weight_adjust = 1.0 / total_weight
        for model in self.models_to_merge:
            model.weight *= weight_adjust
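        # e.g. weights [1.0, 1.0] become [0.5, 0.5], and [1.0, 2.0] becomes [1/3, 2/3]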

        output_model: StableDiffusion = None
        # let's do the merge; it is a good idea to use tqdm to show progress
        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
            # set up the model with our helper class
            sd_model = StableDiffusion(
                device=self.device,
                model_config=model_config,
                dtype="float32"
            )
            # load the model
            sd_model.load_model()
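
            # Note: the loops below rely on state_dict() returning references to the live
            # tensors, so the in-place *= and += operations modify the models directly.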
            # adjust the weights of the text encoder
            if isinstance(sd_model.text_encoder, list):
                # sdxl model
                for text_encoder in sd_model.text_encoder:
                    for key, value in text_encoder.state_dict().items():
                        value *= model_config.weight
            else:
                # normal model
                for key, value in sd_model.text_encoder.state_dict().items():
                    value *= model_config.weight
            # adjust the weights of the unet
            for key, value in sd_model.unet.state_dict().items():
                value *= model_config.weight

            if output_model is None:
                # use this one as the base
                output_model = sd_model
            else:
                # merge the models
                # text encoder
                if isinstance(output_model.text_encoder, list):
                    # sdxl model
                    for i, text_encoder in enumerate(output_model.text_encoder):
                        for key, value in text_encoder.state_dict().items():
                            value += sd_model.text_encoder[i].state_dict()[key]
                else:
                    # normal model
                    for key, value in output_model.text_encoder.state_dict().items():
                        value += sd_model.text_encoder.state_dict()[key]
                # unet
                for key, value in output_model.unet.state_dict().items():
                    value += sd_model.unet.state_dict()[key]

                # remove the model to free memory
                del sd_model
                flush()

        # merge loop is done; let's save the model
        print(f"Saving merged model to {self.save_path}")
        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
        print(f"Saved merged model to {self.save_path}")
        # do cleanup here
        del output_model
        flush()
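
For reference, the merge above is just a weighted average of matching tensors. A standalone sketch of the same math outside the toolkit (the helper name and signature are hypothetical, not part of this commit):

```python
from typing import Dict, List
import torch


def weighted_average(state_dicts: List[Dict[str, torch.Tensor]],
                     weights: List[float]) -> Dict[str, torch.Tensor]:
    # normalize the weights so they sum to 1.0, as the extension does
    total = sum(weights)
    weights = [w / total for w in weights]
    # every state dict is assumed to have identical keys and shapes
    return {
        key: sum(w * sd[key].float() for w, sd in zip(weights, state_dicts))
        for key in state_dicts[0]
    }
```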

25 extensions/example/__init__.py (new file)

@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ExampleMergeExtension(Extension):
    # uid must be unique; it is how the extension is identified
    uid = "example_merge_extension"

    # name is the name of the extension for printing
    name = "Example Merge Extension"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed, and return it
        from .ExampleMergeModels import ExampleMergeModels
        return ExampleMergeModels


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ExampleMergeExtension
]

48 extensions/example/config/config.example.yaml (new file)

@@ -0,0 +1,48 @@
---
# Always include at least one example config file to show how to use your extension.
# Use plenty of comments so users know how to use it and what everything does.

# all extensions will use this job name
job: extension
config:
  name: 'my_awesome_merge'
  process:
    # Put your example processes here. This will be passed
    # to your extension process in the config argument.
    # the type MUST match your extension uid
    - type: "example_merge_extension"
      # save path for the merged model
      save_path: "output/merge/[name].safetensors"
      # save type
      dtype: fp16
      # device to run it on
      device: cuda:0
      # input models can only be SD1.x and SD2.x models for this example (currently)
      models_to_merge:
        # Weights are relative; the total will be normalized.
        # For example, if you have 2 models with weight 1.0, they will
        # both be weighted 0.5. If you have 1 model with weight 1.0 and
        # another with weight 2.0, the first will be weighted 1/3 and the
        # second will be weighted 2/3.
        - name_or_path: "input/model1.safetensors"
          weight: 1.0
        - name_or_path: "input/model2.safetensors"
          weight: 1.0
        - name_or_path: "input/model3.safetensors"
          weight: 0.3
        - name_or_path: "input/model4.safetensors"
          weight: 1.0

# You can put any information you want here, and it will be saved in the model metadata.
# The below is an example; I recommend doing trigger words at a minimum.
# The software will include this plus some other information.
meta:
  name: "[name]" # [name] gets replaced with the name above
  description: A short description of your model
  version: '0.1'
  creator:
    name: Your Name
    email: your@email.com
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
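
With the example weights above (1.0, 1.0, 0.3, and 1.0, totaling 3.3), the normalized weights work out to roughly 0.30, 0.30, 0.09, and 0.30.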

2 info.py

@@ -3,6 +3,6 @@ from collections import OrderedDict
 v = OrderedDict()
 v["name"] = "ai-toolkit"
 v["repo"] = "https://github.com/ostris/ai-toolkit"
-v["version"] = "0.0.2"
+v["version"] = "0.0.3"

 software_meta = v

jobs/BaseJob.py

@@ -60,7 +60,11 @@ class BaseJob:

            # check if the dict key is a known process type
            if process['type'] in process_dict:
                if isinstance(process_dict[process['type']], str):
                    ProcessClass = getattr(module, process_dict[process['type']])
                else:
                    # it is already the class
                    ProcessClass = process_dict[process['type']]
                self.process.append(ProcessClass(i, self, process))
            else:
                raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')

21 jobs/ExtensionJob.py (new file)

@@ -0,0 +1,21 @@
from collections import OrderedDict
from jobs import BaseJob
from toolkit.extension import get_all_extensions_process_dict


class ExtensionJob(BaseJob):

    def __init__(self, config: OrderedDict):
        super().__init__(config)
        self.device = self.get_conf('device', 'cpu')
        self.process_dict = get_all_extensions_process_dict()
        self.load_processes(self.process_dict)

    def run(self):
        super().run()

        print("")
        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

        for process in self.process:
            process.run()

jobs/__init__.py

@@ -4,3 +4,4 @@ from .TrainJob import TrainJob
 from .MergeJob import MergeJob
 from .ModJob import ModJob
 from .GenerateJob import GenerateJob
+from .ExtensionJob import ExtensionJob

20 jobs/process/BaseExtensionProcess.py (new file)

@@ -0,0 +1,20 @@
from collections import OrderedDict
from typing import ForwardRef
from jobs.process.BaseProcess import BaseProcess


class BaseExtensionProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)

    def run(self):
        super().run()

jobs/process/__init__.py

@@ -11,3 +11,4 @@ from .TrainLoRAHack import TrainLoRAHack
 from .TrainSDRescaleProcess import TrainSDRescaleProcess
 from .ModRescaleLoraProcess import ModRescaleLoraProcess
 from .GenerateProcess import GenerateProcess
+from .BaseExtensionProcess import BaseExtensionProcess

56 toolkit/extension.py (new file)

@@ -0,0 +1,56 @@
import os
import importlib
import pkgutil
from typing import List

from toolkit.paths import TOOLKIT_ROOT


class Extension(object):
    """Base class for extensions.

    Extensions are registered with the ExtensionManager, which is
    responsible for calling the extension's load() and unload()
    methods at the appropriate times.
    """

    name: str = None
    uid: str = None

    @classmethod
    def get_process(cls):
        # extend in subclass
        pass


def get_all_extensions() -> List[Extension]:
    # get the path of the "extensions" directory
    extensions_dir = os.path.join(TOOLKIT_ROOT, "extensions")

    # this will hold the classes from all extension modules
    all_extension_classes: List[Extension] = []

    # iterate over all directories (i.e., packages) in the "extensions" directory
    for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
        try:
            # import the module
            module = importlib.import_module(f"extensions.{name}")
            # get the value of the AI_TOOLKIT_EXTENSIONS variable
            extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
            # check that the value is a list
            if isinstance(extensions, list):
                # iterate over the list and add the classes to the main list
                all_extension_classes.extend(extensions)
        except ImportError as e:
            print(f"Failed to import the {name} module. Error: {str(e)}")

    return all_extension_classes


def get_all_extensions_process_dict():
    all_extensions = get_all_extensions()
    process_dict = {}
    for extension in all_extensions:
        process_dict[extension.uid] = extension.get_process()
    return process_dict
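
This dictionary is what ExtensionJob hands to load_processes above, so a config's `type` field resolves straight to a process class. A quick illustration (the lookup key comes from the example extension in this commit):

```python
from toolkit.extension import get_all_extensions_process_dict

process_dict = get_all_extensions_process_dict()
# e.g. {'example_merge_extension': <class 'ExampleMergeModels'>}
ProcessClass = process_dict['example_merge_extension']
```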

toolkit/job.py

@@ -19,6 +19,9 @@ def get_job(config_path, name=None):
     if job == 'generate':
         from jobs import GenerateJob
         return GenerateJob(config)
+    if job == 'extension':
+        from jobs import ExtensionJob
+        return ExtensionJob(config)

     # elif job == 'train':
     #     from jobs import TrainJob

toolkit/stable_diffusion_model.py

@@ -8,7 +8,10 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
from safetensors.torch import save_file
from tqdm import tqdm

from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
    convert_vae_state_dict
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype, apply_noise_offset

@@ -161,6 +164,7 @@ class StableDiffusion:
            scheduler_type='dpm',
            device=self.device_torch,
            load_safety_checker=False,
            requires_safety_checker=False,
        ).to(self.device_torch)
        pipe.register_to_config(requires_safety_checker=False)
        text_encoder = pipe.text_encoder

@@ -468,9 +472,6 @@ class StableDiffusion:
         )

     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
-        # todo see what logit scale is
-        if self.is_xl:
-
         state_dict = {}

         def update_sd(prefix, sd):

@@ -479,6 +480,8 @@ class StableDiffusion:
             v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
             state_dict[key] = v

+        # todo see what logit scale is
+        if self.is_xl:
             # Convert the UNet model
             update_sd("model.diffusion_model.", self.unet.state_dict())

@@ -488,19 +491,25 @@ class StableDiffusion:
            text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
            update_sd("conditioner.embedders.1.model.", text_enc2_dict)

        else:
            # Convert the UNet model
            unet_state_dict = convert_unet_state_dict_to_sd(self.is_v2, self.unet.state_dict())
            update_sd("model.diffusion_model.", unet_state_dict)

            # Convert the text encoder model
            if self.is_v2:
                make_dummy = True
                text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(self.text_encoder.state_dict(), make_dummy)
                update_sd("cond_stage_model.model.", text_enc_dict)
            else:
                text_enc_dict = self.text_encoder.state_dict()
                update_sd("cond_stage_model.transformer.", text_enc_dict)

        # Convert the VAE
        if self.vae is not None:
            vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
            update_sd("first_stage_model.", vae_dict)

        # Put together new checkpoint
        key_count = len(state_dict.keys())
        new_ckpt = {"state_dict": state_dict}

        if model_util.is_safetensors(output_file):
            save_file(state_dict, output_file)
        else:
            torch.save(new_ckpt, output_file, meta)
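            # nb: torch.save's third positional parameter is pickle_module, so passing `meta`
            # above is suspect; the save_file call below attaches metadata explicitly instead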

        return key_count
        else:
            raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet")
        # prepare metadata
        meta = get_meta_for_safetensors(meta)
        save_file(state_dict, output_file, metadata=meta)