mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Moving over to a heavily diffusers-based flow for xl
This commit is contained in:
@@ -3,8 +3,10 @@ import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
import diffusers
|
||||
from safetensors import safe_open
|
||||
|
||||
from library import sdxl_train_util, sdxl_model_util
|
||||
from toolkit.kohya_model_util import load_vae
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.optimizer import get_optimizer
|
||||
@@ -14,7 +16,7 @@ import sys
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDPMScheduler
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
@@ -51,12 +53,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
||||
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
||||
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
|
||||
self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
|
||||
self.first_sample_config = SampleConfig(
|
||||
**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
|
||||
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.sd: 'StableDiffusion' = None
|
||||
|
||||
# sdxl stuff
|
||||
self.logit_scale = None
|
||||
self.ckppt_info = None
|
||||
|
||||
# added later
|
||||
self.network = None
|
||||
|
||||
@@ -223,7 +230,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
|
||||
|
||||
def update_training_metadata(self):
|
||||
self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
|
||||
dict = OrderedDict({
|
||||
"training_info": self.get_training_info()
|
||||
})
|
||||
if self.model_config.is_v2:
|
||||
dict['ss_v2'] = True
|
||||
|
||||
if self.model_config.is_xl:
|
||||
dict['ss_base_model_version'] = 'sdxl_1.0'
|
||||
|
||||
dict['ss_output_name'] = self.job.name
|
||||
|
||||
self.add_meta(dict)
|
||||
|
||||
def get_training_info(self):
|
||||
info = OrderedDict({
|
||||
@@ -473,11 +491,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl(
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
scheduler_name=self.train_config.noise_scheduler,
|
||||
weight_dtype=dtype,
|
||||
dtype=dtype,
|
||||
scheduler_type='pndm',
|
||||
device=self.device_torch
|
||||
)
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
unet = pipe.unet
|
||||
noise_scheduler = pipe.scheduler
|
||||
vae = pipe.vae.to('cpu', dtype=dtype)
|
||||
vae.eval()
|
||||
vae.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
@@ -485,6 +512,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_encoder.eval()
|
||||
|
||||
text_encoder = text_encoders
|
||||
tokenizer = tokenizer
|
||||
del pipe
|
||||
flush()
|
||||
|
||||
|
||||
else:
|
||||
tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
|
||||
self.model_config.name_or_path,
|
||||
@@ -495,11 +527,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
text_encoder.eval()
|
||||
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
|
||||
vae.eval()
|
||||
flush()
|
||||
|
||||
|
||||
# just for now, or if we want to load a custom one
|
||||
# put on cpu for now, we only need it when sampling
|
||||
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
|
||||
vae.eval()
|
||||
# vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
|
||||
# vae.eval()
|
||||
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
|
||||
|
||||
unet.to(self.device_torch, dtype=dtype)
|
||||
@@ -524,7 +560,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
conv_alpha=self.network_config.alpha if conv is not None else None,
|
||||
)
|
||||
|
||||
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
|
||||
self.network.apply_to(
|
||||
|
||||
@@ -13,3 +13,5 @@ kornia
|
||||
invisible-watermark
|
||||
einops
|
||||
accelerate
|
||||
toml
|
||||
albumentations
|
||||
Reference in New Issue
Block a user