Moving over to a highly diffusers flow for xl

This commit is contained in:
Jaret Burkett
2023-07-26 18:25:01 -06:00
parent d3ad195b51
commit 2305e55c82
2 changed files with 46 additions and 9 deletions

View File

@@ -3,8 +3,10 @@ import time
from collections import OrderedDict
import os
import diffusers
from safetensors import safe_open
from library import sdxl_train_util, sdxl_model_util
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -14,7 +16,7 @@ import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDPMScheduler
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
@@ -51,12 +53,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.model_config = ModelConfig(**self.get_conf('model', {}))
self.save_config = SaveConfig(**self.get_conf('save', {}))
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
self.first_sample_config = SampleConfig(
**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None
self.lr_scheduler = None
self.sd: 'StableDiffusion' = None
# sdxl stuff
self.logit_scale = None
self.ckppt_info = None
# added later
self.network = None
@@ -223,7 +230,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
def update_training_metadata(self):
self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
dict = OrderedDict({
"training_info": self.get_training_info()
})
if self.model_config.is_v2:
dict['ss_v2'] = True
if self.model_config.is_xl:
dict['ss_base_model_version'] = 'sdxl_1.0'
dict['ss_output_name'] = self.job.name
self.add_meta(dict)
def get_training_info(self):
info = OrderedDict({
@@ -473,11 +491,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
if self.model_config.is_xl:
tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl(
pipe = StableDiffusionXLPipeline.from_single_file(
self.model_config.name_or_path,
scheduler_name=self.train_config.noise_scheduler,
weight_dtype=dtype,
dtype=dtype,
scheduler_type='pndm',
device=self.device_torch
)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
unet = pipe.unet
noise_scheduler = pipe.scheduler
vae = pipe.vae.to('cpu', dtype=dtype)
vae.eval()
vae.set_use_memory_efficient_attention_xformers(True)
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
@@ -485,6 +512,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_encoder.eval()
text_encoder = text_encoders
tokenizer = tokenizer
del pipe
flush()
else:
tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
self.model_config.name_or_path,
@@ -495,11 +527,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.eval()
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
vae.eval()
flush()
# just for now or of we want to load a custom one
# put on cpu for now, we only need it when sampling
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
vae.eval()
# vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
# vae.eval()
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
unet.to(self.device_torch, dtype=dtype)
@@ -524,7 +560,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
conv_alpha=self.network_config.alpha if conv is not None else None,
)
self.network.force_to(self.device_torch, dtype=dtype)
self.network.apply_to(

View File

@@ -13,3 +13,5 @@ kornia
invisible-watermark
einops
accelerate
toml
albumentations