Pipelines working on SDXL for noise prediction

This commit is contained in:
Jaret Burkett
2023-07-27 11:24:33 -06:00
parent 6ab8b8b0f1
commit 596e57a6a6
3 changed files with 162 additions and 39 deletions

View File

@@ -18,7 +18,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
@@ -500,13 +500,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
if self.model_config.is_xl:
# do our own scheduler
scheduler = KDPM2DiscreteScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
)
pipe = CustomStableDiffusionXLPipeline.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='dpm',
device=self.device_torch
device=self.device_torch,
).to(self.device_torch)
pipe.scheduler = scheduler
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
unet = pipe.unet
@@ -637,10 +645,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar = tqdm(
total=self.train_config.steps,
desc=self.job.name,
leave=True
leave=True,
initial=self.step_num,
iterable=range(0, self.train_config.steps),
)
# set it to our current step in case it was updated from a load
self.progress_bar.update(self.step_num)
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
# todo handle dataloader here maybe, not sure