mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP. just need to put it here
This commit is contained in:
@@ -13,10 +13,12 @@ from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DDPMScheduler
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
@@ -100,15 +102,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# self.sd.tokenizer.to(self.device_torch)
|
||||
# TODO add clip skip
|
||||
if self.sd.is_xl:
|
||||
pipeline = StableDiffusionXLPipeline(
|
||||
vae=self.sd.vae,
|
||||
unet=self.sd.unet,
|
||||
text_encoder=self.sd.text_encoder[0],
|
||||
text_encoder_2=self.sd.text_encoder[1],
|
||||
tokenizer=self.sd.tokenizer[0],
|
||||
tokenizer_2=self.sd.tokenizer[1],
|
||||
scheduler=self.sd.noise_scheduler,
|
||||
)
|
||||
pipeline = self.sd.pipeline
|
||||
else:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=self.sd.vae,
|
||||
@@ -209,7 +203,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
img.save(output_path)
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
if not self.sd.is_xl:
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# restore training state
|
||||
@@ -363,6 +358,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
return None
|
||||
|
||||
def predict_noise_xl(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
positive_prompt: str,
|
||||
negative_prompt: str,
|
||||
timestep: int,
|
||||
guidance_scale=7.5,
|
||||
guidance_rescale=0.7,
|
||||
add_time_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def predict_noise(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
@@ -492,12 +501,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
if self.model_config.is_xl:
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_single_file(
|
||||
pipe = CustomStableDiffusionXLPipeline.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='pndm',
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch
|
||||
)
|
||||
).to(self.device_torch)
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
unet = pipe.unet
|
||||
@@ -513,7 +522,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
text_encoder = text_encoders
|
||||
tokenizer = tokenizer
|
||||
del pipe
|
||||
flush()
|
||||
|
||||
|
||||
@@ -529,6 +537,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_encoder.eval()
|
||||
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
|
||||
vae.eval()
|
||||
pipe = None
|
||||
flush()
|
||||
|
||||
|
||||
@@ -536,7 +545,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# put on cpu for now, we only need it when sampling
|
||||
# vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
|
||||
# vae.eval()
|
||||
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
|
||||
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl, pipeline=pipe)
|
||||
|
||||
unet.to(self.device_torch, dtype=dtype)
|
||||
if self.train_config.xformers:
|
||||
|
||||
Reference in New Issue
Block a user