From 1d2523b9781185cd6f8c63dd650a9e48a27f81a7 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 15 Aug 2023 17:07:34 -0600 Subject: [PATCH] WIP porting to kohya-sdxl. So much to do. --- requirements.txt | 2 +- toolkit/config_modules.py | 1 + toolkit/model_util_sdxl.py | 130 +++++++++++++++++++++ toolkit/pipelines.py | 67 +++++++++++ toolkit/stable_diffusion_model.py | 186 +++++++++++++++++++++--------- 5 files changed, 328 insertions(+), 58 deletions(-) create mode 100644 toolkit/model_util_sdxl.py diff --git a/requirements.txt b/requirements.txt index a11b6211..ef809c80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,4 @@ accelerate toml albumentations pydantic -omegaconf \ No newline at end of file +omegaconf diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index f84273d9..fb69bdef 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -76,6 +76,7 @@ class ModelConfig: self.is_xl: bool = kwargs.get('is_xl', False) self.is_v_pred: bool = kwargs.get('is_v_pred', False) self.dtype: str = kwargs.get('dtype', 'float16') + self.vae_path: str = kwargs.get('vae_path', None) if self.name_or_path is None: raise ValueError('name_or_path must be specified') diff --git a/toolkit/model_util_sdxl.py b/toolkit/model_util_sdxl.py new file mode 100644 index 00000000..4a154918 --- /dev/null +++ b/toolkit/model_util_sdxl.py @@ -0,0 +1,130 @@ +import torch +from diffusers import AutoencoderKL +from safetensors.torch import load_file +from transformers import CLIPTextModelWithProjection, CLIPTextConfig, CLIPTextModel + +from library import model_util, sdxl_original_unet +from library.sdxl_model_util import convert_sdxl_text_encoder_2_checkpoint + + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): + # model_version is reserved for future use + + # Load the state dict + if model_util.is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path, device=map_location) + epoch = None + global_step = None + else: + checkpoint = torch.load(ckpt_path, map_location=map_location) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + epoch = checkpoint.get("epoch", 0) + global_step = checkpoint.get("global_step", 0) + else: + state_dict = checkpoint + epoch = 0 + global_step = 0 + checkpoint = None + + # U-Net + print("building U-Net") + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + + print("loading U-Net from checkpoint") + unet_sd = {} + for k in list(state_dict.keys()): + if k.startswith("model.diffusion_model."): + unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) + info = unet.load_state_dict(unet_sd) + print("U-Net: ", info) + del unet_sd + + # Text Encoders + print("building text encoders") + + # Text Encoder 1 is same to Stability AI's SDXL + text_model1_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + text_model1 = CLIPTextModel._from_config(text_model1_cfg) + + # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace. 
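+    # The checkpoint stores these weights in open clip format under the
+    # "conditioner.embedders.1.model." prefix; convert_sdxl_text_encoder_2_checkpoint()
+    # below remaps them to the HuggingFace CLIPTextModelWithProjection layout
+    # before they are loaded with load_state_dict().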
+ # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer. + text_model2_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + text_model2 = CLIPTextModelWithProjection(text_model2_cfg) + + print("loading text encoders from checkpoint") + te1_sd = {} + te2_sd = {} + for k in list(state_dict.keys()): + if k.endswith("text_model.embeddings.position_ids"): + # skip position_ids + state_dict.pop(k) + elif k.startswith("conditioner.embedders.0.transformer."): + te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) + elif k.startswith("conditioner.embedders.1.model."): + te2_sd[k] = state_dict.pop(k) + + + + info1 = text_model1.load_state_dict(te1_sd) + print("text encoder 1:", info1) + + converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) + # remove text_model.embeddings.position_ids" + converted_sd.pop("text_model.embeddings.position_ids") + info2 = text_model2.load_state_dict(converted_sd) + print("text encoder 2:", info2) + + # prepare vae + print("building VAE") + vae_config = model_util.create_vae_diffusers_config() + vae = AutoencoderKL(**vae_config) # .to(device) + + print("loading VAE from checkpoint") + converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) + info = vae.load_state_dict(converted_vae_checkpoint) + print("VAE:", info) + + ckpt_info = (epoch, global_step) if epoch is not None else None + return text_model1, text_model2, vae, unet, logit_scale, ckpt_info diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py index f772fa1f..c72480c9 100644 --- a/toolkit/pipelines.py +++ b/toolkit/pipelines.py @@ -2,9 +2,76 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable import torch from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline +from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from typing import TYPE_CHECKING + +from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if TYPE_CHECKING: + from diffusers import AutoencoderKL, UNet2DConditionModel + from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + from diffusers.schedulers import KarrasDiffusionSchedulers + + +class FakeWatermarker(StableDiffusionXLWatermarker): + def __init__(self): + super().__init__() + + def apply_watermark(self, image): + return image + + +class HackedStableDiffusionXLPipeline(StableDiffusionXLPipeline): + def __init__( + self, + vae: 'AutoencoderKL', + text_encoder: 'CLIPTextModel', + text_encoder_2: 'CLIPTextModelWithProjection', + tokenizer: 'CLIPTokenizer', + tokenizer_2: 'CLIPTokenizer', + unet: 'UNet2DConditionModel', + scheduler: 'KarrasDiffusionSchedulers', + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool = False, + ): + # call parents parent super skipping parent + super(StableDiffusionXLPipeline, self).__init__() + + 
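+        # Re-do the component wiring that the skipped StableDiffusionXLPipeline.__init__
+        # would normally perform (register_modules, register_to_config, VAE scale factor,
+        # image processor), but install the no-op FakeWatermarker defined above instead
+        # of the real invisible-watermark encoder.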
self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = 1024 + + self.watermark = FakeWatermarker() + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + # passed_add_embed_dim = ( + # self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + # ) + # expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + # + # if expected_add_embed_dim != passed_add_embed_dim: + # raise ValueError( + # f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + # ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): # def __init__(self, *args, **kwargs): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 01533d41..d624c5e3 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -6,25 +6,29 @@ import os from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from safetensors.torch import save_file +from torch.amp import autocast from tqdm import tqdm from torchvision.transforms import Resize -from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ - convert_vae_state_dict +from library.sdxl_train_util import _load_target_model as load_sdxl_target_model from toolkit import train_tools from toolkit.config_modules import ModelConfig, GenerateImageConfig from toolkit.metadata import get_meta_for_safetensors -from toolkit.paths import REPOS_ROOT +from toolkit.model_util_sdxl import load_models_from_sdxl_checkpoint +from toolkit.paths import REPOS_ROOT, MODELS_PATH from toolkit.train_tools import get_torch_dtype, apply_noise_offset sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) +sys.path.append(os.path.join(REPOS_ROOT, 'sd-scripts')) from leco import train_util import torch -from library import model_util -from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl +from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2 +from library import model_util, train_util as kohya_train_util, sdxl_train_util +from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl, DIFFUSERS_SDXL_UNET_CONFIG from diffusers.schedulers import DDPMScheduler -from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline +from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ + HackedStableDiffusionXLPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline import diffusers @@ -76,6 +80,8 @@ class PromptEmbeds: return self +from transformers import CLIPTextModel, CLIPTokenizer, 
CLIPTextModelWithProjection + # if is type checking if typing.TYPE_CHECKING: from diffusers import \ @@ -83,7 +89,6 @@ if typing.TYPE_CHECKING: AutoencoderKL, \ UNet2DConditionModel from diffusers.schedulers import KarrasDiffusionSchedulers - from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection class StableDiffusion: @@ -104,7 +109,7 @@ class StableDiffusion: device, model_config: ModelConfig, dtype='fp16', - custom_pipeline=None + custom_pipeline=None, ): self.custom_pipeline = custom_pipeline self.device = device @@ -116,7 +121,7 @@ class StableDiffusion: # sdxl stuff self.logit_scale = None - self.ckppt_info = None + self.ckpt_info = None self.is_loaded = False # to hold network if there is one @@ -151,35 +156,101 @@ class StableDiffusion: model_path = get_model_path_from_url(self.model_config.name_or_path) if self.model_config.is_xl: - if self.custom_pipeline is not None: - pipln = self.custom_pipeline - else: - pipln = CustomStableDiffusionXLPipeline + # load from kohya + # ( + # load_stable_diffusion_format, + # text_encoder1, + # text_encoder2, + # vae, + # unet, + # logit_scale, + # ckpt_info, + # ) = load_sdxl_target_model( + # model_path, + # self.model_config.vae_path if self.model_config.vae_path is not None else model_path, + # 'sdxl', + # self.dtype, + # self.device, + # ) - # see if path exists - if not os.path.exists(model_path): - # try to load with default diffusers - pipe = pipln.from_pretrained( - model_path, - dtype=dtype, - scheduler_type='ddpm', - device=self.device_torch, - ).to(self.device_torch) - else: - pipe = pipln.from_single_file( - model_path, - dtype=dtype, - scheduler_type='ddpm', - device=self.device_torch, - ).to(self.device_torch) + ( + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = load_models_from_sdxl_checkpoint( + 'sdxl', + model_path, + self.device + ) - text_encoders = [pipe.text_encoder, pipe.text_encoder_2] - tokenizer = [pipe.tokenizer, pipe.tokenizer_2] - for text_encoder in text_encoders: - text_encoder.to(self.device_torch, dtype=dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - text_encoder = text_encoders + class Config: + def __init__(self): + # add all items from DIFFUSERS_SDXL_UNET_CONFIG as attributes + for k, v in DIFFUSERS_SDXL_UNET_CONFIG.items(): + setattr(self, k, v) + + # add diffusers stuff + unet.config = Config() + + if self.model_config.vae_path is not None: + vae = model_util.load_vae(self.model_config.vae_path, self.dtype) + print("additional VAE loaded") + + text_encoder1, text_encoder2, unet = kohya_train_util.transform_models_if_DDP( + [text_encoder1, text_encoder2, unet] + ) + class Args: + tokenizer_cache_dir = os.path.join(MODELS_PATH, 'clip_cache') + + args = Args() + os.makedirs(args.tokenizer_cache_dir, exist_ok=True) + + self.tokenizer = sdxl_train_util.load_tokenizers(args) + # tokenizer1 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer") + # tokenizer2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2") + + gc.collect() + torch.cuda.empty_cache() + self.vae = vae + self.unet = unet.to(self.device_torch, dtype=dtype) + self.text_encoder = [text_encoder1, text_encoder2] + self.logit_scale = logit_scale + self.ckpt_info = ckpt_info + + + + # if self.custom_pipeline is not None: + # pipln = self.custom_pipeline + # else: + # pipln = CustomStableDiffusionXLPipeline + # + # # see if path exists + # if not os.path.exists(model_path): + # # try to load with default diffusers + # pipe = 
pipln.from_pretrained( + # model_path, + # dtype=dtype, + # scheduler_type='ddpm', + # device=self.device_torch, + # ).to(self.device_torch) + # else: + # pipe = pipln.from_single_file( + # model_path, + # dtype=dtype, + # scheduler_type='ddpm', + # device=self.device_torch, + # ).to(self.device_torch) + # + # text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + # tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + # for text_encoder in text_encoders: + # text_encoder.to(self.device_torch, dtype=dtype) + # text_encoder.requires_grad_(False) + # text_encoder.eval() + # text_encoder = text_encoders else: if self.custom_pipeline is not None: pipln = self.custom_pipeline @@ -215,22 +286,22 @@ class StableDiffusion: text_encoder.requires_grad_(False) text_encoder.eval() tokenizer = pipe.tokenizer + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.pipeline = pipe - # scheduler doesn't get set sometimes, so we set it here - pipe.scheduler = scheduler + # scheduler doesn't get set sometimes, so we set it here + pipe.scheduler = scheduler - self.unet = pipe.unet - self.noise_scheduler = pipe.scheduler - self.vae = pipe.vae.to(self.device_torch, dtype=dtype) + self.unet = pipe.unet + self.vae = pipe.vae.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = scheduler self.vae.eval() self.vae.requires_grad_(False) self.unet.to(self.device_torch, dtype=dtype) self.unet.requires_grad_(False) self.unet.eval() - - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.pipeline = pipe self.is_loaded = True def generate_images(self, image_configs: List[GenerateImageConfig]): @@ -265,7 +336,7 @@ class StableDiffusion: # TODO add clip skip if self.is_xl: - pipeline = StableDiffusionXLPipeline( + pipeline = HackedStableDiffusionXLPipeline( vae=self.vae, unet=self.unet, text_encoder=self.text_encoder[0], @@ -310,17 +381,18 @@ class StableDiffusion: torch.cuda.manual_seed(gen_config.seed) if self.is_xl: - img = pipeline( - prompt=gen_config.prompt, - prompt_2=gen_config.prompt_2, - negative_prompt=gen_config.negative_prompt, - negative_prompt_2=gen_config.negative_prompt_2, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - guidance_rescale=gen_config.guidance_rescale, - ).images[0] + with autocast('cuda'): + img = pipeline( + prompt=gen_config.prompt, + prompt_2=gen_config.prompt_2, + negative_prompt=gen_config.negative_prompt, + negative_prompt_2=gen_config.negative_prompt_2, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=gen_config.guidance_rescale, + ).images[0] else: img = pipeline( prompt=gen_config.prompt,