mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Added extensions and an example extension that merges models
This commit is contained in:
56
toolkit/extension.py
Normal file
56
toolkit/extension.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import importlib
|
||||
import pkgutil
|
||||
from typing import List
|
||||
|
||||
from toolkit.paths import TOOLKIT_ROOT
|
||||
|
||||
|
||||
class Extension:
    """Base class for extensions.

    Extensions are registered with the ExtensionManager, which is
    responsible for calling the extension's load() and unload()
    methods at the appropriate times.
    """

    # Human-readable extension name; subclasses are expected to set this.
    name: str = None
    # Unique identifier; used as the key for this extension's process.
    uid: str = None

    @classmethod
    def get_process(cls):
        """Return the process for this extension; base implementation is a no-op."""
        # Subclasses override this; the base class deliberately returns None.
        return None
|
||||
|
||||
|
||||
def get_all_extensions() -> List[Extension]:
    """Collect extension classes from every package under TOOLKIT_ROOT/extensions.

    Each package found there is imported as ``extensions.<name>`` and is
    expected to expose an ``AI_TOOLKIT_EXTENSIONS`` list; the entries of
    those lists are gathered into one flat list. Packages that fail to
    import are reported on stdout and skipped.
    """
    # Directory scanned for extension packages.
    search_path = os.path.join(TOOLKIT_ROOT, "extensions")

    # Accumulates classes contributed by every extension module.
    collected: List[Extension] = []

    for _, module_name, _ in pkgutil.iter_modules([search_path]):
        try:
            ext_module = importlib.import_module(f"extensions.{module_name}")
        except ImportError as err:
            # Best-effort discovery: a broken package must not abort the scan.
            print(f"Failed to import the {module_name} module. Error: {str(err)}")
            continue

        # Only a list-valued AI_TOOLKIT_EXTENSIONS attribute is honored.
        registered = getattr(ext_module, "AI_TOOLKIT_EXTENSIONS", None)
        if isinstance(registered, list):
            collected.extend(registered)

    return collected
|
||||
|
||||
|
||||
def get_all_extensions_process_dict():
    """Build a mapping from each discovered extension's uid to its process.

    Returns a dict of ``{extension.uid: extension.get_process()}`` over all
    extensions found by :func:`get_all_extensions`. If two extensions share
    a uid, the later one wins (plain dict assignment order).
    """
    return {ext.uid: ext.get_process() for ext in get_all_extensions()}
|
||||
@@ -19,6 +19,9 @@ def get_job(config_path, name=None):
|
||||
if job == 'generate':
|
||||
from jobs import GenerateJob
|
||||
return GenerateJob(config)
|
||||
if job == 'extension':
|
||||
from jobs import ExtensionJob
|
||||
return ExtensionJob(config)
|
||||
|
||||
# elif job == 'train':
|
||||
# from jobs import TrainJob
|
||||
|
||||
@@ -8,7 +8,10 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
|
||||
from safetensors.torch import save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
|
||||
@@ -161,6 +164,7 @@ class StableDiffusion:
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch,
|
||||
load_safety_checker=False,
|
||||
requires_safety_checker=False,
|
||||
).to(self.device_torch)
|
||||
pipe.register_to_config(requires_safety_checker=False)
|
||||
text_encoder = pipe.text_encoder
|
||||
@@ -468,17 +472,16 @@ class StableDiffusion:
|
||||
)
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
state_dict = {}
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
key = prefix + k
|
||||
v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
|
||||
state_dict[key] = v
|
||||
|
||||
# todo see what logit scale is
|
||||
if self.is_xl:
|
||||
|
||||
state_dict = {}
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
key = prefix + k
|
||||
v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
|
||||
state_dict[key] = v
|
||||
|
||||
# Convert the UNet model
|
||||
update_sd("model.diffusion_model.", self.unet.state_dict())
|
||||
|
||||
@@ -488,19 +491,25 @@ class StableDiffusion:
|
||||
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
|
||||
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
||||
|
||||
else:
|
||||
# Convert the UNet model
|
||||
unet_state_dict = convert_unet_state_dict_to_sd(self.is_v2, self.unet.state_dict())
|
||||
update_sd("model.diffusion_model.", unet_state_dict)
|
||||
|
||||
# Convert the text encoder model
|
||||
if self.is_v2:
|
||||
make_dummy = True
|
||||
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(self.text_encoder.state_dict(), make_dummy)
|
||||
update_sd("cond_stage_model.model.", text_enc_dict)
|
||||
else:
|
||||
text_enc_dict = self.text_encoder.state_dict()
|
||||
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
||||
|
||||
# Convert the VAE
|
||||
if self.vae is not None:
|
||||
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
||||
update_sd("first_stage_model.", vae_dict)
|
||||
|
||||
# Put together new checkpoint
|
||||
key_count = len(state_dict.keys())
|
||||
new_ckpt = {"state_dict": state_dict}
|
||||
|
||||
if model_util.is_safetensors(output_file):
|
||||
save_file(state_dict, output_file)
|
||||
else:
|
||||
torch.save(new_ckpt, output_file, meta)
|
||||
|
||||
return key_count
|
||||
else:
|
||||
raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet")
|
||||
# prepare metadata
|
||||
meta = get_meta_for_safetensors(meta)
|
||||
save_file(state_dict, output_file, metadata=meta)
|
||||
|
||||
Reference in New Issue
Block a user