Tons of bug fixes and improvements to special training. Fixed slider training.

This commit is contained in:
Jaret Burkett
2023-12-09 16:38:10 -07:00
parent eaec2f5a52
commit eaa0fb6253
9 changed files with 639 additions and 74 deletions

View File

@@ -35,7 +35,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline
StableDiffusionXLImg2ImgPipeline, LCMScheduler
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -279,6 +279,20 @@ class StableDiffusion:
self.load_refiner()
self.is_loaded = True
def te_train(self):
if isinstance(self.text_encoder, list):
for te in self.text_encoder:
te.train()
else:
self.text_encoder.train()
def te_eval(self):
if isinstance(self.text_encoder, list):
for te in self.text_encoder:
te.eval()
else:
self.text_encoder.eval()
def load_refiner(self):
# for now, we are just going to rely on the TE from the base model
# which is TE2 for SDXL and TE for SD (no refiner currently)
@@ -721,6 +735,7 @@ class StableDiffusion:
add_time_ids=None,
conditional_embeddings: Union[PromptEmbeds, None] = None,
unconditional_embeddings: Union[PromptEmbeds, None] = None,
is_input_scaled=True,
**kwargs,
):
with torch.no_grad():
@@ -764,6 +779,8 @@ class StableDiffusion:
def scale_model_input(model_input, timestep_tensor):
if is_input_scaled:
return model_input
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
@@ -859,7 +876,7 @@ class StableDiffusion:
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
@@ -903,7 +920,7 @@ class StableDiffusion:
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
**kwargs,
@@ -924,6 +941,15 @@ class StableDiffusion:
return noise_pred
def step_scheduler(self, model_input, latent_input, timestep_tensor):
# // sometimes they are on the wrong device, no idea why
if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
try:
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
except Exception as e:
pass
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
@@ -955,10 +981,12 @@ class StableDiffusion:
add_time_ids=None,
bleed_ratio: float = 0.5,
bleed_latents: torch.FloatTensor = None,
is_input_scaled=False,
**kwargs,
):
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
for timestep in tqdm(timesteps_to_run, leave=False):
timestep = timestep.unsqueeze_(0)
noise_pred = self.predict_noise(
@@ -967,6 +995,7 @@ class StableDiffusion:
timestep,
guidance_scale=guidance_scale,
add_time_ids=add_time_ids,
is_input_scaled=is_input_scaled,
**kwargs,
)
# some schedulers need to run separately, so do that. (euler for example)
@@ -977,6 +1006,9 @@ class StableDiffusion:
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
# only skip first scaling
is_input_scaled = False
# return latents_steps
return latents