diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 32b55e22..e4105ff8 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -3,8 +3,10 @@ import time from collections import OrderedDict import os +import diffusers from safetensors import safe_open +from library import sdxl_train_util, sdxl_model_util from toolkit.kohya_model_util import load_vae from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer @@ -14,7 +16,7 @@ import sys sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDPMScheduler from jobs.process import BaseTrainProcess from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors @@ -51,12 +53,17 @@ class BaseSDTrainProcess(BaseTrainProcess): self.model_config = ModelConfig(**self.get_conf('model', {})) self.save_config = SaveConfig(**self.get_conf('save', {})) self.sample_config = SampleConfig(**self.get_conf('sample', {})) - self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config + self.first_sample_config = SampleConfig( + **self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config self.logging_config = LogingConfig(**self.get_conf('logging', {})) self.optimizer = None self.lr_scheduler = None self.sd: 'StableDiffusion' = None + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + # added later self.network = None @@ -223,7 +230,18 @@ class BaseSDTrainProcess(BaseTrainProcess): # self.sd.tokenizer.to(original_device_dict['tokenizer']) def update_training_metadata(self): - self.add_meta(OrderedDict({"training_info": self.get_training_info()})) + dict = OrderedDict({ + "training_info": self.get_training_info() + }) + if self.model_config.is_v2: + dict['ss_v2'] = True + + if self.model_config.is_xl: + dict['ss_base_model_version'] = 'sdxl_1.0' + + dict['ss_output_name'] = self.job.name + + self.add_meta(dict) def get_training_info(self): info = OrderedDict({ @@ -473,11 +491,20 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) if self.model_config.is_xl: - tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl( + + pipe = StableDiffusionXLPipeline.from_single_file( self.model_config.name_or_path, - scheduler_name=self.train_config.noise_scheduler, - weight_dtype=dtype, + dtype=dtype, + scheduler_type='pndm', + device=self.device_torch ) + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + unet = pipe.unet + noise_scheduler = pipe.scheduler + vae = pipe.vae.to('cpu', dtype=dtype) + vae.eval() + vae.set_use_memory_efficient_attention_xformers(True) for text_encoder in text_encoders: text_encoder.to(self.device_torch, dtype=dtype) @@ -485,6 +512,11 @@ class BaseSDTrainProcess(BaseTrainProcess): text_encoder.eval() text_encoder = text_encoders + tokenizer = tokenizer + del pipe + flush() + + else: tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models( self.model_config.name_or_path, @@ -495,11 +527,15 @@ class BaseSDTrainProcess(BaseTrainProcess): text_encoder.to(self.device_torch, dtype=dtype) text_encoder.eval() + vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype) + vae.eval() + flush() + # just for now or of we want to load a custom one # put on cpu for now, we only need it when sampling - vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype) - vae.eval() + # vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype) + # vae.eval() self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl) unet.to(self.device_torch, dtype=dtype) @@ -524,7 +560,6 @@ class BaseSDTrainProcess(BaseTrainProcess): conv_alpha=self.network_config.alpha if conv is not None else None, ) - self.network.force_to(self.device_torch, dtype=dtype) self.network.apply_to( diff --git a/requirements.txt b/requirements.txt index df8cf9ea..c1938431 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ kornia invisible-watermark einops accelerate +toml +albumentations \ No newline at end of file