Fixed issues with converting and saving models. Cleaned up keys. Improved testing for load/save cycles.

This commit is contained in:
Jaret Burkett
2023-08-29 12:31:19 -06:00
parent 714854ee86
commit 14ff51ceb4
9 changed files with 784 additions and 1568 deletions

View File

@@ -1,4 +1,5 @@
import gc
import json
import typing
from typing import Union, List, Tuple
import sys
@@ -15,7 +16,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
@@ -37,6 +38,14 @@ SD_PREFIX_TEXT_ENCODER = "te"
# Short string prefixes for a first and a second text encoder.
# NOTE(review): presumably used to namespace state-dict keys for models that
# carry two text encoders -- confirm against the code that reads them.
SD_PREFIX_TEXT_ENCODER1 = "te1"
SD_PREFIX_TEXT_ENCODER2 = "te2"
# Prefixed diffusers keys that prepare_optimizer_params leaves out of the
# optimizer parameter groups: the bias and weight of both time-embedding
# linear layers of the UNet.
DO_NOT_TRAIN_WEIGHTS = [
    f"unet_time_embedding.{layer}.{tensor}"
    for layer in ("linear_1", "linear_2")
    for tensor in ("bias", "weight")
]
class BlankNetwork:
@@ -63,10 +72,6 @@ def flush():
UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4. The same holds for XL.
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
# if is type checking
if typing.TYPE_CHECKING:
from diffusers import \
@@ -734,3 +739,49 @@ class StableDiffusion:
save_dtype=save_dtype,
sd_version=version_string,
)
def prepare_optimizer_params(
        self,
        unet=False,
        text_encoder=False,
        text_encoder_lr=None,
        unet_lr=None,
        default_lr=1e-6,
):
    """Build optimizer parameter groups for the UNet and/or the text encoder.

    Not every weight is written out when the model is saved. To keep training
    and saving in agreement, only weights whose prefixed diffusers key appears
    as a value of ``ldm_diffusers_keymap`` (loaded from the version-specific
    JSON file under ``KEYMAPS_ROOT``) are collected, and keys listed in
    ``DO_NOT_TRAIN_WEIGHTS`` are always skipped.

    Args:
        unet: include a parameter group for the UNet weights.
        text_encoder: include a parameter group for the text encoder weights.
        text_encoder_lr: learning rate of the text encoder group; falls back
            to ``default_lr`` when ``None``.
        unet_lr: learning rate of the UNet group; falls back to
            ``default_lr`` when ``None``.
        default_lr: learning rate used when a specific one is not given.

    Returns:
        A list of ``{"params": [...], "lr": ...}`` dicts, UNet group first
        (if requested), then the text encoder group (if requested).

    Raises:
        OSError: if the keymap file cannot be opened.
        KeyError: if the keymap file has no ``ldm_diffusers_keymap`` entry.
    """
    # todo maybe only get locon ones?
    # NOTE(review): the entries come from self.state_dict(); confirm that it
    # hands back the live trainable parameters rather than detached copies.

    # is_v2 wins over is_xl, which wins over the sd1 default.
    if self.is_v2:
        version = 'sd2'
    elif self.is_xl:
        version = 'sdxl'
    else:
        version = 'sd1'

    mapping_path = os.path.join(KEYMAPS_ROOT, f"stable_diffusion_{version}.json")
    with open(mapping_path, 'r', encoding='utf-8') as f:
        ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap']

    excluded_keys = set(DO_NOT_TRAIN_WEIGHTS)

    def collect_mapped(state_dict):
        # Keep keymap order; drop unmapped and explicitly excluded weights.
        return [
            state_dict[diffusers_key]
            for diffusers_key in ldm_diffusers_keymap.values()
            if diffusers_key in state_dict and diffusers_key not in excluded_keys
        ]

    trainable_parameters = []

    if unet:
        unet_state = self.state_dict(vae=False, unet=unet, text_encoder=False)
        trainable_parameters.append({
            "params": collect_mapped(unet_state),
            "lr": unet_lr if unet_lr is not None else default_lr,
        })

    if text_encoder:
        # Ask for the text encoder only. Passing the unet flag through here
        # would put the UNet weights into this group as well, so they would
        # appear in two parameter groups and pick up the text encoder lr.
        te_state = self.state_dict(vae=False, unet=False, text_encoder=text_encoder)
        trainable_parameters.append({
            "params": collect_mapped(te_state),
            "lr": text_encoder_lr if text_encoder_lr is not None else default_lr,
        })

    return trainable_parameters