WIP. Just need to put it here.

This commit is contained in:
Jaret Burkett
2023-07-27 01:46:30 -06:00
parent 2305e55c82
commit 6ab8b8b0f1
4 changed files with 279 additions and 63 deletions

View File

@@ -13,10 +13,12 @@ from toolkit.optimizer import get_optimizer
from toolkit.paths import REPOS_ROOT
import sys
from toolkit.pipelines import CustomStableDiffusionXLPipeline
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDPMScheduler
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
@@ -100,15 +102,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# self.sd.tokenizer.to(self.device_torch)
# TODO add clip skip
if self.sd.is_xl:
pipeline = StableDiffusionXLPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=self.sd.noise_scheduler,
)
pipeline = self.sd.pipeline
else:
pipeline = StableDiffusionPipeline(
vae=self.sd.vae,
@@ -209,7 +203,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
img.save(output_path)
# clear pipeline and cache to reduce vram usage
del pipeline
if not self.sd.is_xl:
del pipeline
torch.cuda.empty_cache()
# restore training state
@@ -363,6 +358,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
return None
# Placeholder method: the signature is declared but the body is only `pass`,
# so every call currently returns None and none of the arguments are used.
# NOTE(review): presumably intended as the SDXL counterpart of `predict_noise`
# (defined right after this method) -- confirm once it is implemented.
def predict_noise_xl(
self,
latents: torch.FloatTensor,
positive_prompt: str,
negative_prompt: str,
timestep: int,
# Defaults below are the values callers get when they omit these arguments.
guidance_scale=7.5,
guidance_rescale=0.7,
# NOTE(review): looks like SDXL time-id conditioning; expected type/shape is not shown here -- TODO confirm.
add_time_ids=None,
# Extra keyword arguments are accepted but currently ignored.
**kwargs,
):
# Not implemented yet (work in progress).
pass
def predict_noise(
self,
latents: torch.FloatTensor,
@@ -492,12 +501,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.model_config.is_xl:
pipe = StableDiffusionXLPipeline.from_single_file(
pipe = CustomStableDiffusionXLPipeline.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='pndm',
scheduler_type='dpm',
device=self.device_torch
)
).to(self.device_torch)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
unet = pipe.unet
@@ -513,7 +522,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_encoder = text_encoders
tokenizer = tokenizer
del pipe
flush()
@@ -529,6 +537,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_encoder.eval()
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
vae.eval()
pipe = None
flush()
@@ -536,7 +545,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# put on cpu for now, we only need it when sampling
# vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
# vae.eval()
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl, pipeline=pipe)
unet.to(self.device_torch, dtype=dtype)
if self.train_config.xformers: