mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Fixed some mismatched weights by adjusting tolerance. The mismatch ironically made the models better lol
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
import gc
|
||||
import json
|
||||
import typing
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Iterator
|
||||
import sys
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file
|
||||
from safetensors.torch import save_file, load_file
|
||||
from torch.nn import Parameter
|
||||
from tqdm import tqdm
|
||||
from torchvision.transforms import Resize
|
||||
|
||||
@@ -31,12 +32,12 @@ import diffusers
|
||||
# tell it to shut up
|
||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||
|
||||
VAE_PREFIX_UNET = "vae"
|
||||
SD_PREFIX_VAE = "vae"
|
||||
SD_PREFIX_UNET = "unet"
|
||||
SD_PREFIX_TEXT_ENCODER = "te"
|
||||
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te1"
|
||||
SD_PREFIX_TEXT_ENCODER2 = "te2"
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te0"
|
||||
SD_PREFIX_TEXT_ENCODER2 = "te1"
|
||||
|
||||
# prefixed diffusers keys
|
||||
DO_NOT_TRAIN_WEIGHTS = [
|
||||
@@ -184,6 +185,21 @@ class StableDiffusion:
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
text_encoder = text_encoders
|
||||
|
||||
if self.model_config.experimental_xl:
|
||||
print("Experimental XL mode enabled")
|
||||
print("Loading and injecting alt weights")
|
||||
# load the mismatched weight and force it in
|
||||
raw_state_dict = load_file(model_path)
|
||||
replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone()
|
||||
del raw_state_dict
|
||||
# get state dict for the 2nd text encoder
|
||||
te1_state_dict = text_encoders[1].state_dict()
|
||||
# replace weight with mismatched weight
|
||||
te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
print("Injecting alt weights")
|
||||
|
||||
else:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
@@ -707,7 +723,7 @@ class StableDiffusion:
|
||||
state_dict = OrderedDict()
|
||||
if vae:
|
||||
for k, v in self.vae.state_dict().items():
|
||||
new_key = k if k.startswith(f"{VAE_PREFIX_UNET}") else f"{VAE_PREFIX_UNET}_{k}"
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
|
||||
state_dict[new_key] = v
|
||||
if text_encoder:
|
||||
if isinstance(self.text_encoder, list):
|
||||
@@ -726,6 +742,35 @@ class StableDiffusion:
|
||||
state_dict[new_key] = v
|
||||
return state_dict
|
||||
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[str, Parameter]:
    """Collect the named parameters of the selected sub-models into one ordered mapping.

    Parameters are gathered in a fixed order: VAE, then text encoder(s), then
    UNet. Each sub-model's ``named_parameters`` is called with a prefix
    (``SD_PREFIX_VAE``, ``SD_PREFIX_TEXT_ENCODER`` or ``SD_PREFIX_UNET``).
    When ``self.text_encoder`` is a list, every encoder gets the text encoder
    prefix followed by its index in the list.

    Args:
        vae: include the parameters of ``self.vae``.
        text_encoder: include the parameters of ``self.text_encoder``.
        unet: include the parameters of ``self.unet``.
        state_dict_keys: when true, rewrite every key by replacing only its
            first ``.`` with ``_`` so the keys follow the state dict naming.

    Returns:
        An ``OrderedDict`` mapping parameter name to parameter.
    """
    collected: OrderedDict[str, Parameter] = OrderedDict()

    def _collect(module, prefix):
        # Later duplicates overwrite the value but keep the first position,
        # the same as plain item assignment in a loop.
        collected.update(module.named_parameters(recurse=True, prefix=prefix))

    if vae:
        _collect(self.vae, f"{SD_PREFIX_VAE}")

    if text_encoder:
        encoders = self.text_encoder
        if isinstance(encoders, list):
            for index, encoder in enumerate(encoders):
                _collect(encoder, f"{SD_PREFIX_TEXT_ENCODER}{index}")
        else:
            _collect(encoders, f"{SD_PREFIX_TEXT_ENCODER}")

    if unet:
        _collect(self.unet, f"{SD_PREFIX_UNET}")

    if not state_dict_keys:
        return collected

    # Convert to state dict keys: only the first "." (the one that follows the
    # prefix) becomes "_"; the rest of the dotted path is left untouched.
    return OrderedDict(
        (name.replace('.', '_', 1), param) for name, param in collected.items()
    )
|
||||
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
version_string = '1'
|
||||
if self.is_v2:
|
||||
@@ -764,24 +809,31 @@ class StableDiffusion:
|
||||
|
||||
trainable_parameters = []
|
||||
|
||||
# we use state dict to find params
|
||||
|
||||
if unet:
|
||||
state_dict = self.state_dict(vae=False, unet=unet, text_encoder=False)
|
||||
named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
|
||||
unet_lr = unet_lr if unet_lr is not None else default_lr
|
||||
params = []
|
||||
for key, diffusers_key in ldm_diffusers_keymap.items():
|
||||
if diffusers_key in state_dict and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
|
||||
params.append(state_dict[diffusers_key])
|
||||
if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
|
||||
if named_params[diffusers_key].requires_grad:
|
||||
params.append(named_params[diffusers_key])
|
||||
param_data = {"params": params, "lr": unet_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
print(f"Found {len(params)} trainable parameter in unet")
|
||||
|
||||
if text_encoder:
|
||||
state_dict = self.state_dict(vae=False, unet=unet, text_encoder=text_encoder)
|
||||
named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True)
|
||||
text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr
|
||||
params = []
|
||||
for key, diffusers_key in ldm_diffusers_keymap.items():
|
||||
if diffusers_key in state_dict and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
|
||||
params.append(state_dict[diffusers_key])
|
||||
if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
|
||||
if named_params[diffusers_key].requires_grad:
|
||||
params.append(named_params[diffusers_key])
|
||||
param_data = {"params": params, "lr": text_encoder_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
|
||||
print(f"Found {len(params)} trainable parameter in text encoder")
|
||||
|
||||
return trainable_parameters
|
||||
|
||||
Reference in New Issue
Block a user