mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
WIP porting to kohya-sdxl. So much to do.
This commit is contained in:
@@ -6,25 +6,29 @@ import os
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file
|
||||
from torch.amp import autocast
|
||||
from tqdm import tqdm
|
||||
from torchvision.transforms import Resize
|
||||
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict
|
||||
from library.sdxl_train_util import _load_target_model as load_sdxl_target_model
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.model_util_sdxl import load_models_from_sdxl_checkpoint
|
||||
from toolkit.paths import REPOS_ROOT, MODELS_PATH
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'sd-scripts'))
|
||||
from leco import train_util
|
||||
import torch
|
||||
from library import model_util
|
||||
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2
|
||||
from library import model_util, train_util as kohya_train_util, sdxl_train_util
|
||||
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl, DIFFUSERS_SDXL_UNET_CONFIG
|
||||
from diffusers.schedulers import DDPMScheduler
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
HackedStableDiffusionXLPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
import diffusers
|
||||
|
||||
@@ -76,6 +80,8 @@ class PromptEmbeds:
|
||||
return self
|
||||
|
||||
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||
|
||||
# if is type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
from diffusers import \
|
||||
@@ -83,7 +89,6 @@ if typing.TYPE_CHECKING:
|
||||
AutoencoderKL, \
|
||||
UNet2DConditionModel
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
@@ -104,7 +109,7 @@ class StableDiffusion:
|
||||
device,
|
||||
model_config: ModelConfig,
|
||||
dtype='fp16',
|
||||
custom_pipeline=None
|
||||
custom_pipeline=None,
|
||||
):
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.device = device
|
||||
@@ -116,7 +121,7 @@ class StableDiffusion:
|
||||
|
||||
# sdxl stuff
|
||||
self.logit_scale = None
|
||||
self.ckppt_info = None
|
||||
self.ckpt_info = None
|
||||
self.is_loaded = False
|
||||
|
||||
# to hold network if there is one
|
||||
@@ -151,35 +156,101 @@ class StableDiffusion:
|
||||
model_path = get_model_path_from_url(self.model_config.name_or_path)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
pipln = CustomStableDiffusionXLPipeline
|
||||
# load from kohya
|
||||
# (
|
||||
# load_stable_diffusion_format,
|
||||
# text_encoder1,
|
||||
# text_encoder2,
|
||||
# vae,
|
||||
# unet,
|
||||
# logit_scale,
|
||||
# ckpt_info,
|
||||
# ) = load_sdxl_target_model(
|
||||
# model_path,
|
||||
# self.model_config.vae_path if self.model_config.vae_path is not None else model_path,
|
||||
# 'sdxl',
|
||||
# self.dtype,
|
||||
# self.device,
|
||||
# )
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
(
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = load_models_from_sdxl_checkpoint(
|
||||
'sdxl',
|
||||
model_path,
|
||||
self.device
|
||||
)
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
text_encoder = text_encoders
|
||||
class Config:
|
||||
def __init__(self):
|
||||
# add all items from DIFFUSERS_SDXL_UNET_CONFIG as attributes
|
||||
for k, v in DIFFUSERS_SDXL_UNET_CONFIG.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
# add diffusers stuff
|
||||
unet.config = Config()
|
||||
|
||||
if self.model_config.vae_path is not None:
|
||||
vae = model_util.load_vae(self.model_config.vae_path, self.dtype)
|
||||
print("additional VAE loaded")
|
||||
|
||||
text_encoder1, text_encoder2, unet = kohya_train_util.transform_models_if_DDP(
|
||||
[text_encoder1, text_encoder2, unet]
|
||||
)
|
||||
class Args:
|
||||
tokenizer_cache_dir = os.path.join(MODELS_PATH, 'clip_cache')
|
||||
|
||||
args = Args()
|
||||
os.makedirs(args.tokenizer_cache_dir, exist_ok=True)
|
||||
|
||||
self.tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||
# tokenizer1 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
||||
# tokenizer2 = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer_2")
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
self.vae = vae
|
||||
self.unet = unet.to(self.device_torch, dtype=dtype)
|
||||
self.text_encoder = [text_encoder1, text_encoder2]
|
||||
self.logit_scale = logit_scale
|
||||
self.ckpt_info = ckpt_info
|
||||
|
||||
|
||||
|
||||
# if self.custom_pipeline is not None:
|
||||
# pipln = self.custom_pipeline
|
||||
# else:
|
||||
# pipln = CustomStableDiffusionXLPipeline
|
||||
#
|
||||
# # see if path exists
|
||||
# if not os.path.exists(model_path):
|
||||
# # try to load with default diffusers
|
||||
# pipe = pipln.from_pretrained(
|
||||
# model_path,
|
||||
# dtype=dtype,
|
||||
# scheduler_type='ddpm',
|
||||
# device=self.device_torch,
|
||||
# ).to(self.device_torch)
|
||||
# else:
|
||||
# pipe = pipln.from_single_file(
|
||||
# model_path,
|
||||
# dtype=dtype,
|
||||
# scheduler_type='ddpm',
|
||||
# device=self.device_torch,
|
||||
# ).to(self.device_torch)
|
||||
#
|
||||
# text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
# tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
# for text_encoder in text_encoders:
|
||||
# text_encoder.to(self.device_torch, dtype=dtype)
|
||||
# text_encoder.requires_grad_(False)
|
||||
# text_encoder.eval()
|
||||
# text_encoder = text_encoders
|
||||
else:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
@@ -215,22 +286,22 @@ class StableDiffusion:
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
tokenizer = pipe.tokenizer
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.pipeline = pipe
|
||||
|
||||
# scheduler doesn't get set sometimes, so we set it here
|
||||
pipe.scheduler = scheduler
|
||||
# scheduler doesn't get set sometimes, so we set it here
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
self.unet = pipe.unet
|
||||
self.noise_scheduler = pipe.scheduler
|
||||
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
|
||||
self.unet = pipe.unet
|
||||
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
|
||||
|
||||
self.noise_scheduler = scheduler
|
||||
self.vae.eval()
|
||||
self.vae.requires_grad_(False)
|
||||
self.unet.to(self.device_torch, dtype=dtype)
|
||||
self.unet.requires_grad_(False)
|
||||
self.unet.eval()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.pipeline = pipe
|
||||
self.is_loaded = True
|
||||
|
||||
def generate_images(self, image_configs: List[GenerateImageConfig]):
|
||||
@@ -265,7 +336,7 @@ class StableDiffusion:
|
||||
|
||||
# TODO add clip skip
|
||||
if self.is_xl:
|
||||
pipeline = StableDiffusionXLPipeline(
|
||||
pipeline = HackedStableDiffusionXLPipeline(
|
||||
vae=self.vae,
|
||||
unet=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
@@ -310,17 +381,18 @@ class StableDiffusion:
|
||||
torch.cuda.manual_seed(gen_config.seed)
|
||||
|
||||
if self.is_xl:
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
prompt_2=gen_config.prompt_2,
|
||||
negative_prompt=gen_config.negative_prompt,
|
||||
negative_prompt_2=gen_config.negative_prompt_2,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
guidance_rescale=gen_config.guidance_rescale,
|
||||
).images[0]
|
||||
with autocast('cuda'):
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
prompt_2=gen_config.prompt_2,
|
||||
negative_prompt=gen_config.negative_prompt,
|
||||
negative_prompt_2=gen_config.negative_prompt_2,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
guidance_rescale=gen_config.guidance_rescale,
|
||||
).images[0]
|
||||
else:
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
|
||||
Reference in New Issue
Block a user