mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Information trainer
This commit is contained in:
@@ -13,7 +13,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler, PNDMScheduler, \
|
||||
DDIMScheduler, DDPMScheduler
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
@@ -38,8 +39,9 @@ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||
|
||||
|
||||
class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
|
||||
super().__init__(process_id, job, config)
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.step_num = 0
|
||||
self.start_step = 0
|
||||
self.device = self.get_conf('device', self.job.device)
|
||||
@@ -271,6 +273,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
self.print(f"Saved to {file_path}")
|
||||
self.clean_up_saves()
|
||||
|
||||
# Called before the model is loaded
|
||||
def hook_before_model_load(self):
|
||||
@@ -467,18 +470,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# TODO handle other schedulers
|
||||
sch = KDPM2DiscreteScheduler
|
||||
# do our own scheduler
|
||||
scheduler = KDPM2DiscreteScheduler(
|
||||
scheduler = sch(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.0120,
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
if self.model_config.is_xl:
|
||||
pipe = CustomStableDiffusionXLPipeline.from_single_file(
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
pipln = CustomStableDiffusionXLPipeline
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
|
||||
@@ -490,7 +499,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_encoder.eval()
|
||||
text_encoder = text_encoders
|
||||
else:
|
||||
pipe = CustomStableDiffusionPipeline.from_single_file(
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
pipln = CustomStableDiffusionPipeline
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
@@ -614,7 +627,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
if self.has_first_sample_requested:
|
||||
self.print("Generating first sample from first sample config")
|
||||
self.sample(0, is_first=False)
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
# sample first
|
||||
if self.train_config.skip_first_sample:
|
||||
|
||||
Reference in New Issue
Block a user