WIP porting to kohya-sdxl. So much to do.

This commit is contained in:
Jaret Burkett
2023-08-15 17:07:34 -06:00
parent 55a5fcc7d9
commit 1d2523b978
5 changed files with 328 additions and 58 deletions

View File

@@ -6,25 +6,29 @@ import os
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file
from torch.amp import autocast
from tqdm import tqdm
from torchvision.transforms import Resize
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict
from library.sdxl_train_util import _load_target_model as load_sdxl_target_model
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.model_util_sdxl import load_models_from_sdxl_checkpoint
from toolkit.paths import REPOS_ROOT, MODELS_PATH
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
sys.path.append(os.path.join(REPOS_ROOT, 'sd-scripts'))
from leco import train_util
import torch
from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2
from library import model_util, train_util as kohya_train_util, sdxl_train_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl, DIFFUSERS_SDXL_UNET_CONFIG
from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
HackedStableDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import diffusers
@@ -76,6 +80,8 @@ class PromptEmbeds:
return self
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
# if is type checking
if typing.TYPE_CHECKING:
from diffusers import \
@@ -83,7 +89,6 @@ if typing.TYPE_CHECKING:
AutoencoderKL, \
UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
class StableDiffusion:
@@ -104,7 +109,7 @@ class StableDiffusion:
device,
model_config: ModelConfig,
dtype='fp16',
custom_pipeline=None
custom_pipeline=None,
):
self.custom_pipeline = custom_pipeline
self.device = device
@@ -116,7 +121,7 @@ class StableDiffusion:
# sdxl stuff
self.logit_scale = None
self.ckppt_info = None
self.ckpt_info = None
self.is_loaded = False
# to hold network if there is one
@@ -151,35 +156,101 @@ class StableDiffusion:
model_path = get_model_path_from_url(self.model_config.name_or_path)
if self.model_config.is_xl:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = CustomStableDiffusionXLPipeline
# load from kohya
# (
# load_stable_diffusion_format,
# text_encoder1,
# text_encoder2,
# vae,
# unet,
# logit_scale,
# ckpt_info,
# ) = load_sdxl_target_model(
# model_path,
# self.model_config.vae_path if self.model_config.vae_path is not None else model_path,
# 'sdxl',
# self.dtype,
# self.device,
# )
# see if path exists
if not os.path.exists(model_path):
# try to load with default diffusers
pipe = pipln.from_pretrained(
model_path,
dtype=dtype,
scheduler_type='ddpm',
device=self.device_torch,
).to(self.device_torch)
else:
pipe = pipln.from_single_file(
model_path,
dtype=dtype,
scheduler_type='ddpm',
device=self.device_torch,
).to(self.device_torch)
(
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = load_models_from_sdxl_checkpoint(
'sdxl',
model_path,
self.device
)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
class Config:
def __init__(self):
# add all items from DIFFUSERS_SDXL_UNET_CONFIG as attributes
for k, v in DIFFUSERS_SDXL_UNET_CONFIG.items():
setattr(self, k, v)
# add diffusers stuff
unet.config = Config()
if self.model_config.vae_path is not None:
vae = model_util.load_vae(self.model_config.vae_path, self.dtype)
print("additional VAE loaded")
text_encoder1, text_encoder2, unet = kohya_train_util.transform_models_if_DDP(
[text_encoder1, text_encoder2, unet]
)
class Args:
tokenizer_cache_dir = os.path.join(MODELS_PATH, 'clip_cache')
args = Args()
os.makedirs(args.tokenizer_cache_dir, exist_ok=True)
self.tokenizer = sdxl_train_util.load_tokenizers(args)
# tokenizer1 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
# tokenizer2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2")
gc.collect()
torch.cuda.empty_cache()
self.vae = vae
self.unet = unet.to(self.device_torch, dtype=dtype)
self.text_encoder = [text_encoder1, text_encoder2]
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
# if self.custom_pipeline is not None:
# pipln = self.custom_pipeline
# else:
# pipln = CustomStableDiffusionXLPipeline
#
# # see if path exists
# if not os.path.exists(model_path):
# # try to load with default diffusers
# pipe = pipln.from_pretrained(
# model_path,
# dtype=dtype,
# scheduler_type='ddpm',
# device=self.device_torch,
# ).to(self.device_torch)
# else:
# pipe = pipln.from_single_file(
# model_path,
# dtype=dtype,
# scheduler_type='ddpm',
# device=self.device_torch,
# ).to(self.device_torch)
#
# text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
# tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
# for text_encoder in text_encoders:
# text_encoder.to(self.device_torch, dtype=dtype)
# text_encoder.requires_grad_(False)
# text_encoder.eval()
# text_encoder = text_encoders
else:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
@@ -215,22 +286,22 @@ class StableDiffusion:
text_encoder.requires_grad_(False)
text_encoder.eval()
tokenizer = pipe.tokenizer
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.pipeline = pipe
# scheduler doesn't get set sometimes, so we set it here
pipe.scheduler = scheduler
# scheduler doesn't get set sometimes, so we set it here
pipe.scheduler = scheduler
self.unet = pipe.unet
self.noise_scheduler = pipe.scheduler
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
self.unet = pipe.unet
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
self.noise_scheduler = scheduler
self.vae.eval()
self.vae.requires_grad_(False)
self.unet.to(self.device_torch, dtype=dtype)
self.unet.requires_grad_(False)
self.unet.eval()
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.pipeline = pipe
self.is_loaded = True
def generate_images(self, image_configs: List[GenerateImageConfig]):
@@ -265,7 +336,7 @@ class StableDiffusion:
# TODO add clip skip
if self.is_xl:
pipeline = StableDiffusionXLPipeline(
pipeline = HackedStableDiffusionXLPipeline(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder[0],
@@ -310,17 +381,18 @@ class StableDiffusion:
torch.cuda.manual_seed(gen_config.seed)
if self.is_xl:
img = pipeline(
prompt=gen_config.prompt,
prompt_2=gen_config.prompt_2,
negative_prompt=gen_config.negative_prompt,
negative_prompt_2=gen_config.negative_prompt_2,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
guidance_rescale=gen_config.guidance_rescale,
).images[0]
with autocast('cuda'):
img = pipeline(
prompt=gen_config.prompt,
prompt_2=gen_config.prompt_2,
negative_prompt=gen_config.negative_prompt,
negative_prompt_2=gen_config.negative_prompt_2,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
guidance_rescale=gen_config.guidance_rescale,
).images[0]
else:
img = pipeline(
prompt=gen_config.prompt,