mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 08:49:14 +00:00
fixed issues with converting and saving models. Cleaned keys. Improved testing for cycle load saving.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import gc
|
||||
import json
|
||||
import typing
|
||||
from typing import Union, List, Tuple
|
||||
import sys
|
||||
@@ -15,7 +16,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
||||
from toolkit.saving import save_ldm_model_from_diffusers
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
@@ -37,6 +38,14 @@ SD_PREFIX_TEXT_ENCODER = "te"
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te1"
|
||||
SD_PREFIX_TEXT_ENCODER2 = "te2"
|
||||
|
||||
# Prefixed diffusers state-dict keys that must be excluded from training
|
||||
DO_NOT_TRAIN_WEIGHTS = [
|
||||
"unet_time_embedding.linear_1.bias",
|
||||
"unet_time_embedding.linear_1.weight",
|
||||
"unet_time_embedding.linear_2.bias",
|
||||
"unet_time_embedding.linear_2.weight",
|
||||
]
|
||||
|
||||
|
||||
class BlankNetwork:
|
||||
|
||||
@@ -63,10 +72,6 @@ def flush():
|
||||
UNET_IN_CHANNELS = 4 # in_channels is fixed at 4 for Stable Diffusion; the same holds for XL.
|
||||
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# if is type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
from diffusers import \
|
||||
@@ -734,3 +739,49 @@ class StableDiffusion:
|
||||
save_dtype=save_dtype,
|
||||
sd_version=version_string,
|
||||
)
|
||||
|
||||
def prepare_optimizer_params(
        self,
        unet=False,
        text_encoder=False,
        text_encoder_lr=None,
        unet_lr=None,
        default_lr=1e-6,
):
    """Build optimizer parameter groups restricted to weights present in the save keymap.

    Not every weight is written out when the model is saved, so only weights
    whose diffusers key appears in the ``ldm_diffusers_keymap`` of the
    version-specific keymap file are selected. Keys listed in
    ``DO_NOT_TRAIN_WEIGHTS`` are always left out.

    Args:
        unet: Include a parameter group for the UNet weights.
        text_encoder: Include a parameter group for the text encoder weights.
        text_encoder_lr: Learning rate for the text encoder group; falls back
            to ``default_lr`` when ``None``.
        unet_lr: Learning rate for the UNet group; falls back to
            ``default_lr`` when ``None``.
        default_lr: Learning rate used when a group-specific one is not given.

    Returns:
        A list of ``{"params": [...], "lr": ...}`` dicts, one per requested
        component (UNet first, then text encoder). The list is empty when
        neither component is requested.
    """
    # todo maybe only get locon ones?
    # Select the keymap file matching the model family. The v2 check runs
    # last, so it takes precedence if both flags are set.
    version = 'sd1'
    if self.is_xl:
        version = 'sdxl'
    if self.is_v2:
        version = 'sd2'
    mapping_filename = f"stable_diffusion_{version}.json"
    mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename)
    with open(mapping_path, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']

    # Mapped diffusers keys in keymap order, de-duplicated so one weight can
    # never be listed twice inside a group, with frozen weights removed.
    frozen_keys = set(DO_NOT_TRAIN_WEIGHTS)
    trainable_keys = [
        diffusers_key
        for diffusers_key in dict.fromkeys(ldm_diffusers_keymap.values())
        if diffusers_key not in frozen_keys
    ]

    def build_group(state_dict, lr):
        # Keep only the mapped weights that exist in this component's state dict.
        params = [state_dict[k] for k in trainable_keys if k in state_dict]
        return {"params": params, "lr": lr}

    trainable_parameters = []

    if unet:
        unet_state_dict = self.state_dict(vae=False, unet=True, text_encoder=False)
        unet_lr = unet_lr if unet_lr is not None else default_lr
        trainable_parameters.append(build_group(unet_state_dict, unet_lr))

    if text_encoder:
        # Request ONLY the text encoder here. Including the UNet again would
        # put the same weights into two parameter groups (with the wrong
        # learning rate), which torch optimizers reject.
        te_state_dict = self.state_dict(vae=False, unet=False, text_encoder=True)
        text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr
        trainable_parameters.append(build_group(te_state_dict, text_encoder_lr))

    return trainable_parameters
|
||||
|
||||
Reference in New Issue
Block a user