Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 08:29:45 +00:00)
WIP porting to kohya-sdxl. So much to do.
requirements.txt

@@ -15,4 +15,4 @@ accelerate
toml
albumentations
pydantic
omegaconf
toolkit/config_modules.py

@@ -76,6 +76,7 @@ class ModelConfig:
        self.is_xl: bool = kwargs.get('is_xl', False)
        self.is_v_pred: bool = kwargs.get('is_v_pred', False)
        self.dtype: str = kwargs.get('dtype', 'float16')
        self.vae_path: str = kwargs.get('vae_path', None)

        if self.name_or_path is None:
            raise ValueError('name_or_path must be specified')
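For orientation, a minimal sketch of how these fields might be set for an SDXL run (values illustrative, not from this commit):

    config = ModelConfig(
        name_or_path='/models/sd_xl_base_1.0.safetensors',  # required; None raises ValueError
        is_xl=True,            # route loading through the SDXL code path
        is_v_pred=False,       # set for v-prediction checkpoints
        dtype='float16',
        vae_path=None,         # optional: path to a standalone VAE to load instead
    )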
toolkit/model_util_sdxl.py (new file, 130 lines)

@@ -0,0 +1,130 @@
import torch
from diffusers import AutoencoderKL
from safetensors.torch import load_file
from transformers import CLIPTextModelWithProjection, CLIPTextConfig, CLIPTextModel

from library import model_util, sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_text_encoder_2_checkpoint


def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
    # model_version is reserved for future use

    # Load the state dict
    if model_util.is_safetensors(ckpt_path):
        checkpoint = None
        state_dict = load_file(ckpt_path, device=map_location)
        epoch = None
        global_step = None
    else:
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint.get("epoch", 0)
            global_step = checkpoint.get("global_step", 0)
        else:
            state_dict = checkpoint
            epoch = 0
            global_step = 0
        checkpoint = None

    # U-Net
    print("building U-Net")
    unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    print("loading U-Net from checkpoint")
    unet_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = unet.load_state_dict(unet_sd)
    print("U-Net: ", info)
    del unet_sd

    # Text Encoders
    print("building text encoders")

    # Text Encoder 1 is the same as Stability AI's SDXL
    text_model1_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    text_model1 = CLIPTextModel._from_config(text_model1_cfg)

    # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
    # Note: the tokenizer from HuggingFace is different from SDXL's. We must use open clip's tokenizer.
    text_model2_cfg = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

    print("loading text encoders from checkpoint")
    te1_sd = {}
    te2_sd = {}
    for k in list(state_dict.keys()):
        if k.endswith("text_model.embeddings.position_ids"):
            # skip position_ids
            state_dict.pop(k)
        elif k.startswith("conditioner.embedders.0.transformer."):
            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
        elif k.startswith("conditioner.embedders.1.model."):
            te2_sd[k] = state_dict.pop(k)
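To make the prefix handling concrete: a first-encoder checkpoint key such as conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight (an illustrative key, not taken from this diff) is stripped to text_model.encoder.layers.0.mlp.fc1.weight before loading, while second-encoder keys keep their full conditioner.embedders.1.model. prefix because the conversion step below expects it.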
    info1 = text_model1.load_state_dict(te1_sd)
    print("text encoder 1:", info1)

    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    # remove text_model.embeddings.position_ids
    converted_sd.pop("text_model.embeddings.position_ids")
    info2 = text_model2.load_state_dict(converted_sd)
    print("text encoder 2:", info2)

    # prepare vae
    print("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    vae = AutoencoderKL(**vae_config)  # .to(device)

    print("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)
    print("VAE:", info)

    ckpt_info = (epoch, global_step) if epoch is not None else None
    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
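A minimal usage sketch for this loader (path and device strings are illustrative):

    # 'sdxl' fills the reserved model_version argument; map_location follows torch conventions.
    text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info = \
        load_models_from_sdxl_checkpoint('sdxl', '/models/sd_xl_base_1.0.safetensors', 'cpu')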
toolkit/pipelines.py

@@ -2,9 +2,76 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg

from typing import TYPE_CHECKING

from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

if TYPE_CHECKING:
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
    from diffusers.schedulers import KarrasDiffusionSchedulers


class FakeWatermarker(StableDiffusionXLWatermarker):
    def __init__(self):
        super().__init__()

    def apply_watermark(self, image):
        return image
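FakeWatermarker disables the invisible watermark that the stock SDXL pipeline stamps onto outputs, by overriding apply_watermark to pass images through untouched. A quick sketch of the resulting behavior (tensor shape illustrative):

    import torch

    wm = FakeWatermarker()
    images = torch.zeros(1, 3, 1024, 1024)
    assert wm.apply_watermark(images) is images  # returned unchanged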
class HackedStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    def __init__(
        self,
        vae: 'AutoencoderKL',
        text_encoder: 'CLIPTextModel',
        text_encoder_2: 'CLIPTextModelWithProjection',
        tokenizer: 'CLIPTokenizer',
        tokenizer_2: 'CLIPTokenizer',
        unet: 'UNet2DConditionModel',
        scheduler: 'KarrasDiffusionSchedulers',
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: bool = False,
    ):
        # call the grandparent's __init__ directly, skipping StableDiffusionXLPipeline's
        super(StableDiffusionXLPipeline, self).__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = 1024

        self.watermark = FakeWatermarker()

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        # passed_add_embed_dim = (
        #     self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
        # )
        # expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
        #
        # if expected_add_embed_dim != passed_add_embed_dim:
        #     raise ValueError(
        #         f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        #     )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids
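A worked example of the micro-conditioning built above: for a 1024x1024 generation with no cropping, original_size=(1024, 1024), crops_coords_top_left=(0, 0) and target_size=(1024, 1024), so:

    add_time_ids = list((1024, 1024) + (0, 0) + (1024, 1024))
    # -> [1024, 1024, 0, 0, 1024, 1024]; torch.tensor([add_time_ids]) gives shape (1, 6)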

class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
    # def __init__(self, *args, **kwargs):
toolkit/stable_diffusion_model.py

@@ -6,25 +6,29 @@ import os

from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file
from torch.amp import autocast
from tqdm import tqdm
from torchvision.transforms import Resize

from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
    convert_vae_state_dict
from library.sdxl_train_util import _load_target_model as load_sdxl_target_model
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.model_util_sdxl import load_models_from_sdxl_checkpoint
from toolkit.paths import REPOS_ROOT, MODELS_PATH
from toolkit.train_tools import get_torch_dtype, apply_noise_offset

sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
sys.path.append(os.path.join(REPOS_ROOT, 'sd-scripts'))
from leco import train_util
import torch
from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2
from library import model_util, train_util as kohya_train_util, sdxl_train_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl, DIFFUSERS_SDXL_UNET_CONFIG
from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
    HackedStableDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import diffusers
@@ -76,6 +80,8 @@ class PromptEmbeds:
        return self


from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

# only imported when type checking
if typing.TYPE_CHECKING:
    from diffusers import \
@@ -83,7 +89,6 @@ if typing.TYPE_CHECKING:
        AutoencoderKL, \
        UNet2DConditionModel
    from diffusers.schedulers import KarrasDiffusionSchedulers
    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection


class StableDiffusion:
@@ -104,7 +109,7 @@ class StableDiffusion:
            device,
            model_config: ModelConfig,
            dtype='fp16',
            custom_pipeline=None
            custom_pipeline=None,
    ):
        self.custom_pipeline = custom_pipeline
        self.device = device
@@ -116,7 +121,7 @@ class StableDiffusion:

        # sdxl stuff
        self.logit_scale = None
        self.ckppt_info = None
        self.ckpt_info = None
        self.is_loaded = False

        # to hold network if there is one
@@ -151,35 +156,101 @@ class StableDiffusion:
        model_path = get_model_path_from_url(self.model_config.name_or_path)

        if self.model_config.is_xl:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
            else:
                pipln = CustomStableDiffusionXLPipeline
            # load from kohya
            # (
            #     load_stable_diffusion_format,
            #     text_encoder1,
            #     text_encoder2,
            #     vae,
            #     unet,
            #     logit_scale,
            #     ckpt_info,
            # ) = load_sdxl_target_model(
            #     model_path,
            #     self.model_config.vae_path if self.model_config.vae_path is not None else model_path,
            #     'sdxl',
            #     self.dtype,
            #     self.device,
            # )

            # see if path exists
            if not os.path.exists(model_path):
                # try to load with default diffusers
                pipe = pipln.from_pretrained(
                    model_path,
                    dtype=dtype,
                    scheduler_type='ddpm',
                    device=self.device_torch,
                ).to(self.device_torch)
            else:
                pipe = pipln.from_single_file(
                    model_path,
                    dtype=dtype,
                    scheduler_type='ddpm',
                    device=self.device_torch,
                ).to(self.device_torch)
            (
                text_encoder1,
                text_encoder2,
                vae,
                unet,
                logit_scale,
                ckpt_info,
            ) = load_models_from_sdxl_checkpoint(
                'sdxl',
                model_path,
                self.device
            )

            text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
            tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
            for text_encoder in text_encoders:
                text_encoder.to(self.device_torch, dtype=dtype)
                text_encoder.requires_grad_(False)
                text_encoder.eval()
            text_encoder = text_encoders

            class Config:
                def __init__(self):
                    # add all items from DIFFUSERS_SDXL_UNET_CONFIG as attributes
                    for k, v in DIFFUSERS_SDXL_UNET_CONFIG.items():
                        setattr(self, k, v)

            # add diffusers stuff
            unet.config = Config()
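The Config shim exists because diffusers code reads U-Net settings as attributes (e.g. unet.config.addition_time_embed_dim in the commented-out check in pipelines.py above), while kohya's DIFFUSERS_SDXL_UNET_CONFIG is a plain dict. A generic sketch of the same dict-to-attributes pattern (names illustrative):

    class AttrConfig:
        def __init__(self, d):
            for k, v in d.items():
                setattr(self, k, v)

    cfg = AttrConfig({'in_channels': 4, 'addition_time_embed_dim': 256})
    assert cfg.in_channels == 4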
            if self.model_config.vae_path is not None:
                vae = model_util.load_vae(self.model_config.vae_path, self.dtype)
                print("additional VAE loaded")

            text_encoder1, text_encoder2, unet = kohya_train_util.transform_models_if_DDP(
                [text_encoder1, text_encoder2, unet]
            )

            class Args:
                tokenizer_cache_dir = os.path.join(MODELS_PATH, 'clip_cache')

            args = Args()
            os.makedirs(args.tokenizer_cache_dir, exist_ok=True)

            self.tokenizer = sdxl_train_util.load_tokenizers(args)
            # tokenizer1 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
            # tokenizer2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2")

            gc.collect()
            torch.cuda.empty_cache()
            self.vae = vae
            self.unet = unet.to(self.device_torch, dtype=dtype)
            self.text_encoder = [text_encoder1, text_encoder2]
            self.logit_scale = logit_scale
            self.ckpt_info = ckpt_info

            # if self.custom_pipeline is not None:
            #     pipln = self.custom_pipeline
            # else:
            #     pipln = CustomStableDiffusionXLPipeline
            #
            # # see if path exists
            # if not os.path.exists(model_path):
            #     # try to load with default diffusers
            #     pipe = pipln.from_pretrained(
            #         model_path,
            #         dtype=dtype,
            #         scheduler_type='ddpm',
            #         device=self.device_torch,
            #     ).to(self.device_torch)
            # else:
            #     pipe = pipln.from_single_file(
            #         model_path,
            #         dtype=dtype,
            #         scheduler_type='ddpm',
            #         device=self.device_torch,
            #     ).to(self.device_torch)
            #
            # text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
            # tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
            # for text_encoder in text_encoders:
            #     text_encoder.to(self.device_torch, dtype=dtype)
            #     text_encoder.requires_grad_(False)
            #     text_encoder.eval()
            # text_encoder = text_encoders
        else:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
@@ -215,22 +286,22 @@ class StableDiffusion:
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            tokenizer = pipe.tokenizer
            self.tokenizer = tokenizer
            self.text_encoder = text_encoder
            self.pipeline = pipe

            # scheduler doesn't get set sometimes, so we set it here
            pipe.scheduler = scheduler
        # scheduler doesn't get set sometimes, so we set it here
        pipe.scheduler = scheduler

            self.unet = pipe.unet
            self.noise_scheduler = pipe.scheduler
            self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
        self.unet = pipe.unet
        self.vae = pipe.vae.to(self.device_torch, dtype=dtype)

        self.noise_scheduler = scheduler
        self.vae.eval()
        self.vae.requires_grad_(False)
        self.unet.to(self.device_torch, dtype=dtype)
        self.unet.requires_grad_(False)
        self.unet.eval()

        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.pipeline = pipe
        self.is_loaded = True
    def generate_images(self, image_configs: List[GenerateImageConfig]):
@@ -265,7 +336,7 @@ class StableDiffusion:

        # TODO add clip skip
        if self.is_xl:
            pipeline = StableDiffusionXLPipeline(
            pipeline = HackedStableDiffusionXLPipeline(
                vae=self.vae,
                unet=self.unet,
                text_encoder=self.text_encoder[0],
@@ -310,17 +381,18 @@ class StableDiffusion:
                torch.cuda.manual_seed(gen_config.seed)

                if self.is_xl:
                    img = pipeline(
                        prompt=gen_config.prompt,
                        prompt_2=gen_config.prompt_2,
                        negative_prompt=gen_config.negative_prompt,
                        negative_prompt_2=gen_config.negative_prompt_2,
                        height=gen_config.height,
                        width=gen_config.width,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        guidance_rescale=gen_config.guidance_rescale,
                    ).images[0]
                    with autocast('cuda'):
                        img = pipeline(
                            prompt=gen_config.prompt,
                            prompt_2=gen_config.prompt_2,
                            negative_prompt=gen_config.negative_prompt,
                            negative_prompt_2=gen_config.negative_prompt_2,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            guidance_rescale=gen_config.guidance_rescale,
                        ).images[0]
                else:
                    img = pipeline(
                        prompt=gen_config.prompt,
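The substantive change in this hunk is that SDXL generation now runs under torch.amp autocast, presumably so half-precision weights and full-precision ops can mix without manual casts. The pattern in isolation (prompt and step count illustrative):

    from torch.amp import autocast

    with autocast('cuda'):
        img = pipeline(prompt='a photo of a cat', num_inference_steps=30).images[0]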