Merge branch 'sdxl' into WIP

# Conflicts:
#	jobs/process/BaseSDTrainProcess.py
#	jobs/process/TrainSliderProcess.py
This commit is contained in:
Jaret Burkett
2023-07-29 14:29:18 -06:00
2 changed files with 263 additions and 193 deletions

View File

@@ -103,7 +103,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
# self.sd.text_encoder.to(self.device_torch)
# self.sd.tokenizer.to(self.device_torch)
# TODO add clip skip
pipeline = self.sd.pipeline
if self.sd.is_xl:
pipeline = StableDiffusionXLPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=self.sd.noise_scheduler,
)
else:
pipeline = StableDiffusionPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder,
tokenizer=self.sd.tokenizer,
scheduler=self.sd.noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
@@ -162,16 +182,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
torch.manual_seed(current_seed)
torch.cuda.manual_seed(current_seed)
img = pipeline(
prompt=prompt,
prompt_2=prompt,
negative_prompt=neg,
negative_prompt_2=neg,
height=height,
width=width,
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
).images[0]
if self.sd.is_xl:
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
).images[0]
else:
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
).images[0]
step_num = ''
if step is not None:
@@ -184,6 +212,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
output_path = os.path.join(sample_folder, filename)
img.save(output_path)
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
# restore training state
@@ -259,12 +289,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
# TODO handle dreambooth, fine tuning, etc
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta
)
self.network.multiplier = prev_multiplier
else:
self.sd.save(
file_path,
@@ -340,19 +373,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
return None
def predict_noise_xl(
self,
latents: torch.FloatTensor,
positive_prompt: str,
negative_prompt: str,
timestep: int,
guidance_scale=7.5,
guidance_rescale=0.7,
add_time_ids=None,
**kwargs,
):
pass
def predict_noise(
self,
latents: torch.FloatTensor,