mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Pipelines working on SDXL for noise prediction
This commit is contained in:
@@ -18,7 +18,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
@@ -500,13 +500,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
# do our own scheduler
|
||||
scheduler = KDPM2DiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.0120,
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
|
||||
pipe = CustomStableDiffusionXLPipeline.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
pipe.scheduler = scheduler
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
unet = pipe.unet
|
||||
@@ -637,10 +645,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.progress_bar = tqdm(
|
||||
total=self.train_config.steps,
|
||||
desc=self.job.name,
|
||||
leave=True
|
||||
leave=True,
|
||||
initial=self.step_num,
|
||||
iterable=range(0, self.train_config.steps),
|
||||
)
|
||||
# set it to our current step in case it was updated from a load
|
||||
self.progress_bar.update(self.step_num)
|
||||
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
# todo handle dataloader here maybe, not sure
|
||||
|
||||
Reference in New Issue
Block a user