WIP porting to kohya-sdxl. So much to do.

This commit is contained in:
Jaret Burkett
2023-08-15 17:07:34 -06:00
parent 55a5fcc7d9
commit 1d2523b978
5 changed files with 328 additions and 58 deletions


@@ -15,4 +15,4 @@ accelerate
toml
albumentations
pydantic
omegaconf


@@ -76,6 +76,7 @@ class ModelConfig:
        self.is_xl: bool = kwargs.get('is_xl', False)
        self.is_v_pred: bool = kwargs.get('is_v_pred', False)
        self.dtype: str = kwargs.get('dtype', 'float16')
        self.vae_path: str = kwargs.get('vae_path', None)
        if self.name_or_path is None:
            raise ValueError('name_or_path must be specified')
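
For reference, a minimal sketch of how the new vae_path option might be passed; the paths are placeholders, and the keyword-style construction assumes ModelConfig forwards **kwargs as the kwargs.get calls above suggest:

model_config = ModelConfig(
    name_or_path='/path/to/sd_xl_base_1.0.safetensors',  # placeholder path
    is_xl=True,
    dtype='float16',
    vae_path='/path/to/sdxl_vae.safetensors',  # optional standalone VAE; defaults to None
)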

toolkit/model_util_sdxl.py (new file, +130 lines)

@@ -0,0 +1,130 @@
import torch
from diffusers import AutoencoderKL
from safetensors.torch import load_file
from transformers import CLIPTextModelWithProjection, CLIPTextConfig, CLIPTextModel
from library import model_util, sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_text_encoder_2_checkpoint

def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
    # model_version is reserved for future use
    # Load the state dict
    if model_util.is_safetensors(ckpt_path):
        checkpoint = None
        state_dict = load_file(ckpt_path, device=map_location)
        epoch = None
        global_step = None
    else:
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint.get("epoch", 0)
            global_step = checkpoint.get("global_step", 0)
        else:
            state_dict = checkpoint
            epoch = 0
            global_step = 0
        checkpoint = None

    # U-Net
    print("building U-Net")
    unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    print("loading U-Net from checkpoint")
    unet_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = unet.load_state_dict(unet_sd)
    print("U-Net: ", info)
    del unet_sd
    # Text Encoders
    print("building text encoders")

    # Text Encoder 1 is the same as Stability AI's SDXL
    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    # Text Encoder 2 differs from Stability AI's SDXL: SDXL uses open clip, but we use the model from Hugging Face.
    # Note: the Hugging Face tokenizer also differs from SDXL's, so open clip's tokenizer must be used.
    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
print("loading text encoders from checkpoint")
te1_sd = {}
te2_sd = {}
for k in list(state_dict.keys()):
if k.endswith("text_model.embeddings.position_ids"):
# skip position_ids
state_dict.pop(k)
elif k.startswith("conditioner.embedders.0.transformer."):
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
info1 = text_model1.load_state_dict(te1_sd)
print("text encoder 1:", info1)
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
# remove text_model.embeddings.position_ids"
converted_sd.pop("text_model.embeddings.position_ids")
info2 = text_model2.load_state_dict(converted_sd)
print("text encoder 2:", info2)
# prepare vae
print("building VAE")
vae_config = model_util.create_vae_diffusers_config()
vae = AutoencoderKL(**vae_config) # .to(device)
print("loading VAE from checkpoint")
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
info = vae.load_state_dict(converted_vae_checkpoint)
print("VAE:", info)
ckpt_info = (epoch, global_step) if epoch is not None else None
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
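
A minimal usage sketch for the loader above; the checkpoint path is a placeholder and the device/dtype choices are assumptions:

# load on CPU first, then move the pieces where they are needed
te1, te2, vae, unet, logit_scale, ckpt_info = load_models_from_sdxl_checkpoint(
    'sdxl', '/path/to/sd_xl_base_1.0.safetensors', 'cpu'
)
unet.to('cuda', dtype=torch.float16)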


@@ -2,9 +2,76 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from typing import TYPE_CHECKING
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
if TYPE_CHECKING:
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
    from diffusers.schedulers import KarrasDiffusionSchedulers


class FakeWatermarker(StableDiffusionXLWatermarker):
    def __init__(self):
        super().__init__()

    def apply_watermark(self, image):
        # no-op: return the image unchanged instead of embedding the invisible watermark
        return image
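
The stock SDXL pipeline embeds an invisible watermark in every generated image; the stub above keeps the watermarker interface but passes images through untouched. A quick sketch of the effect, where images stands for any image batch:

watermarker = FakeWatermarker()
out = watermarker.apply_watermark(images)
assert out is images  # nothing was embedded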

class HackedStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    def __init__(
            self,
            vae: 'AutoencoderKL',
            text_encoder: 'CLIPTextModel',
            text_encoder_2: 'CLIPTextModelWithProjection',
            tokenizer: 'CLIPTokenizer',
            tokenizer_2: 'CLIPTokenizer',
            unet: 'UNet2DConditionModel',
            scheduler: 'KarrasDiffusionSchedulers',
            force_zeros_for_empty_prompt: bool = True,
            add_watermarker: bool = False,
    ):
        # call the grandparent's __init__, skipping StableDiffusionXLPipeline's own
        super(StableDiffusionXLPipeline, self).__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 1024

        self.watermark = FakeWatermarker()

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        # the stock pipeline's embedding-dimension check is skipped here:
        # passed_add_embed_dim = (
        #     self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
        # )
        # expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        #
        # if expected_add_embed_dim != passed_add_embed_dim:
        #     raise ValueError(
        #         f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        #     )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids
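
The override simply concatenates the three size tuples into one row tensor and skips the stock dimension check. A worked example, assuming pipe is an instance of the class above:

ids = pipe._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), dtype=torch.float32)
# ids == tensor([[1024., 1024., 0., 0., 1024., 1024.]]), shape (1, 6)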

class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    # def __init__(self, *args, **kwargs):


@@ -6,25 +6,29 @@ import os
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file
from torch.amp import autocast
from tqdm import tqdm
from torchvision.transforms import Resize
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
    convert_vae_state_dict
from library.sdxl_train_util import _load_target_model as load_sdxl_target_model
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.model_util_sdxl import load_models_from_sdxl_checkpoint
from toolkit.paths import REPOS_ROOT, MODELS_PATH
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
sys.path.append(os.path.join(REPOS_ROOT, 'sd-scripts'))
from leco import train_util
import torch
from library import model_util, train_util as kohya_train_util, sdxl_train_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl, DIFFUSERS_SDXL_UNET_CONFIG
from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
    HackedStableDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import diffusers
@@ -76,6 +80,8 @@ class PromptEmbeds:
        return self


from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

# imports only used for type checking
if typing.TYPE_CHECKING:
    from diffusers import \
@@ -83,7 +89,6 @@ if typing.TYPE_CHECKING:
        AutoencoderKL, \
        UNet2DConditionModel
    from diffusers.schedulers import KarrasDiffusionSchedulers
class StableDiffusion:
@@ -104,7 +109,7 @@ class StableDiffusion:
            device,
            model_config: ModelConfig,
            dtype='fp16',
            custom_pipeline=None,
    ):
        self.custom_pipeline = custom_pipeline
        self.device = device
@@ -116,7 +121,7 @@ class StableDiffusion:
        # sdxl stuff
        self.logit_scale = None
        self.ckpt_info = None
        self.is_loaded = False

        # to hold network if there is one
@@ -151,35 +156,101 @@ class StableDiffusion:
        model_path = get_model_path_from_url(self.model_config.name_or_path)

        if self.model_config.is_xl:
            # load from kohya
            # (
            #     load_stable_diffusion_format,
            #     text_encoder1,
            #     text_encoder2,
            #     vae,
            #     unet,
            #     logit_scale,
            #     ckpt_info,
            # ) = load_sdxl_target_model(
            #     model_path,
            #     self.model_config.vae_path if self.model_config.vae_path is not None else model_path,
            #     'sdxl',
            #     self.dtype,
            #     self.device,
            # )
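            # note: the commented block above is kohya's higher-level loader, kept
            # for reference; the active path below calls load_models_from_sdxl_checkpoint
            # directly, so the VAE override and DDP wrapping are handled by hand below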
            (
                text_encoder1,
                text_encoder2,
                vae,
                unet,
                logit_scale,
                ckpt_info,
            ) = load_models_from_sdxl_checkpoint(
                'sdxl',
                model_path,
                self.device
            )
            class Config:
                def __init__(self):
                    # add all items from DIFFUSERS_SDXL_UNET_CONFIG as attributes
                    for k, v in DIFFUSERS_SDXL_UNET_CONFIG.items():
                        setattr(self, k, v)

            # add diffusers stuff
            unet.config = Config()
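            # note: diffusers code reads attributes such as unet.config.in_channels,
            # which kohya's hand-built U-Net does not carry, so a plain attribute
            # object built from DIFFUSERS_SDXL_UNET_CONFIG stands in for it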
            if self.model_config.vae_path is not None:
                vae = model_util.load_vae(self.model_config.vae_path, self.dtype)
                print("additional VAE loaded")

            text_encoder1, text_encoder2, unet = kohya_train_util.transform_models_if_DDP(
                [text_encoder1, text_encoder2, unet]
            )
            class Args:
                tokenizer_cache_dir = os.path.join(MODELS_PATH, 'clip_cache')

            args = Args()
            os.makedirs(args.tokenizer_cache_dir, exist_ok=True)
            self.tokenizer = sdxl_train_util.load_tokenizers(args)
            # tokenizer1 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
            # tokenizer2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2")
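            # note: the minimal Args shim is there to satisfy sdxl_train_util.load_tokenizers,
            # which looks up tokenizer_cache_dir on its args and returns the two CLIP
            # tokenizers (the second set up for SDXL's open clip text encoder)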
            gc.collect()
            torch.cuda.empty_cache()

            self.vae = vae
            self.unet = unet.to(self.device_torch, dtype=dtype)
            self.text_encoder = [text_encoder1, text_encoder2]
            self.logit_scale = logit_scale
            self.ckpt_info = ckpt_info

            # if self.custom_pipeline is not None:
            #     pipln = self.custom_pipeline
            # else:
            #     pipln = CustomStableDiffusionXLPipeline
            #
            # # see if path exists
            # if not os.path.exists(model_path):
            #     # try to load with default diffusers
            #     pipe = pipln.from_pretrained(
            #         model_path,
            #         dtype=dtype,
            #         scheduler_type='ddpm',
            #         device=self.device_torch,
            #     ).to(self.device_torch)
            # else:
            #     pipe = pipln.from_single_file(
            #         model_path,
            #         dtype=dtype,
            #         scheduler_type='ddpm',
            #         device=self.device_torch,
            #     ).to(self.device_torch)
            #
            # text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
            # tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
            # for text_encoder in text_encoders:
            #     text_encoder.to(self.device_torch, dtype=dtype)
            #     text_encoder.requires_grad_(False)
            #     text_encoder.eval()
            # text_encoder = text_encoders
        else:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
@@ -215,22 +286,22 @@ class StableDiffusion:
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            tokenizer = pipe.tokenizer

            # scheduler doesn't get set sometimes, so we set it here
            pipe.scheduler = scheduler

            self.unet = pipe.unet
            self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
            self.noise_scheduler = scheduler
            self.vae.eval()
            self.vae.requires_grad_(False)
            self.unet.to(self.device_torch, dtype=dtype)
            self.unet.requires_grad_(False)
            self.unet.eval()

            self.tokenizer = tokenizer
            self.text_encoder = text_encoder
            self.pipeline = pipe

        self.is_loaded = True

    def generate_images(self, image_configs: List[GenerateImageConfig]):
def generate_images(self, image_configs: List[GenerateImageConfig]):
@@ -265,7 +336,7 @@ class StableDiffusion:
        # TODO add clip skip
        if self.is_xl:
            pipeline = HackedStableDiffusionXLPipeline(
                vae=self.vae,
                unet=self.unet,
                text_encoder=self.text_encoder[0],
@@ -310,17 +381,18 @@ class StableDiffusion:
            torch.cuda.manual_seed(gen_config.seed)

            if self.is_xl:
                with autocast('cuda'):
                    img = pipeline(
                        prompt=gen_config.prompt,
                        prompt_2=gen_config.prompt_2,
                        negative_prompt=gen_config.negative_prompt,
                        negative_prompt_2=gen_config.negative_prompt_2,
                        height=gen_config.height,
                        width=gen_config.width,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        guidance_rescale=gen_config.guidance_rescale,
                    ).images[0]
            else:
                img = pipeline(
                    prompt=gen_config.prompt,