mirror of https://github.com/ostris/ai-toolkit.git
Tons of bug fixes and improvements to special training. Fixed slider training.
@@ -35,7 +35,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
-    StableDiffusionXLImg2ImgPipeline
+    StableDiffusionXLImg2ImgPipeline, LCMScheduler
 import diffusers
 from diffusers import \
     AutoencoderKL, \
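The newly imported LCMScheduler is referenced later in the scheduler device workaround; outside this file it is typically swapped in for a pipeline's default scheduler when sampling with an LCM-distilled model. A minimal sketch using the standard diffusers API (the checkpoint id and step count below are illustrative assumptions; a real LCM setup would also load an LCM-distilled model or LCM-LoRA):

# Hedged sketch: swap a pipeline's scheduler for LCMScheduler (diffusers API).
import torch
from diffusers import StableDiffusionPipeline, LCMScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# reuse the existing scheduler config so beta schedule etc. stay consistent
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
image = pipe("a photo of a cat", num_inference_steps=4, guidance_scale=1.0).images[0]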
@@ -279,6 +279,20 @@ class StableDiffusion:
             self.load_refiner()
         self.is_loaded = True
 
+    def te_train(self):
+        if isinstance(self.text_encoder, list):
+            for te in self.text_encoder:
+                te.train()
+        else:
+            self.text_encoder.train()
+
+    def te_eval(self):
+        if isinstance(self.text_encoder, list):
+            for te in self.text_encoder:
+                te.eval()
+        else:
+            self.text_encoder.eval()
+
     def load_refiner(self):
         # for now, we are just going to rely on the TE from the base model
         # which is TE2 for SDXL and TE for SD (no refiner currently)
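The new te_train/te_eval helpers toggle train/eval mode regardless of whether the model carries a single text encoder (SD) or a list of two (SDXL). A minimal usage sketch, assuming an already-loaded instance of this wrapper named sd (the variable name is illustrative):

# Hedged sketch: flip the text encoder(s) into train mode only while the text
# encoder is actually being trained, and back to eval before sampling.
sd.te_train()   # one TE (SD) or both TEs (SDXL)
# ... run text-encoder training steps here ...
sd.te_eval()    # restore inference behavior before generating samples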
@@ -721,6 +735,7 @@ class StableDiffusion:
             add_time_ids=None,
             conditional_embeddings: Union[PromptEmbeds, None] = None,
             unconditional_embeddings: Union[PromptEmbeds, None] = None,
+            is_input_scaled=True,
             **kwargs,
     ):
         with torch.no_grad():
@@ -764,6 +779,8 @@ class StableDiffusion:
 
 
         def scale_model_input(model_input, timestep_tensor):
+            if is_input_scaled:
+                return model_input
             mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
             timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
             out_chunks = []
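The is_input_scaled flag lets a caller pass latents that already went through the scheduler's scale_model_input, so the inner helper returns them untouched instead of scaling twice. A standalone sketch of the per-sample scaling pattern used here, assuming a diffusers scheduler instance and a batch of latents with one timestep per batch item:

# Hedged sketch of per-sample scheduler scaling with torch.chunk, mirroring the
# scale_model_input helper in the diff. "scheduler", "latents" and "timesteps"
# are assumed inputs.
import torch

def scale_per_sample(scheduler, latents, timesteps, is_input_scaled=False):
    if is_input_scaled:
        return latents  # caller already applied scale_model_input; avoid double scaling
    sample_chunks = torch.chunk(latents, latents.shape[0], dim=0)
    timestep_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
    scaled = [
        scheduler.scale_model_input(sample, t)
        for sample, t in zip(sample_chunks, timestep_chunks)
    ]
    return torch.cat(scaled, dim=0)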
@@ -859,7 +876,7 @@ class StableDiffusion:
 
                 # predict the noise residual
                 noise_pred = self.unet(
-                    latent_model_input,
+                    latent_model_input.to(self.device_torch, self.torch_dtype),
                     timestep,
                     encoder_hidden_states=text_embeddings.text_embeds,
                     added_cond_kwargs=added_cond_kwargs,
@@ -903,7 +920,7 @@ class StableDiffusion:
 
                 # predict the noise residual
                 noise_pred = self.unet(
-                    latent_model_input,
+                    latent_model_input.to(self.device_torch, self.torch_dtype),
                     timestep,
                     encoder_hidden_states=text_embeddings.text_embeds,
                     **kwargs,
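Both UNet call sites now cast the latent input to the model's device and dtype before the forward pass, which avoids float32/float16 or CPU/GPU mismatches when the latents were prepared elsewhere. A minimal sketch of the same pattern, assuming unet, device, dtype, timestep and text_embeds are already set up (names are illustrative):

# Hedged sketch: match the UNet input to the model's device and dtype.
latent_model_input = latent_model_input.to(device, dtype)  # e.g. cuda / float16
noise_pred = unet(
    latent_model_input,
    timestep,
    encoder_hidden_states=text_embeds,
).sample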
@@ -924,6 +941,15 @@ class StableDiffusion:
         return noise_pred
 
     def step_scheduler(self, model_input, latent_input, timestep_tensor):
+        # // sometimes they are on the wrong device, no idea why
+        if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
+            try:
+                self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
+                self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
+                self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
+            except Exception as e:
+                pass
+
         mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
         latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
         timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
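The new DDPM/LCM branch works around scheduler buffers (betas, alphas, alphas_cumprod) occasionally ending up on the CPU while the latents live on the GPU, which would otherwise make scheduler.step raise a device-mismatch error. A generic sketch of the same workaround, assuming a diffusers scheduler instance and a target device (the helper name is an illustrative assumption, not part of the repo):

# Hedged sketch: push the scheduler's precomputed tensors onto the compute device.
# Attribute names follow diffusers' DDPMScheduler/LCMScheduler; the defensive
# handling mirrors the try/except in the diff.
def move_scheduler_to(scheduler, device):
    for name in ("betas", "alphas", "alphas_cumprod"):
        try:
            setattr(scheduler, name, getattr(scheduler, name).to(device))
        except AttributeError:
            pass  # scheduler variant without this buffer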
@@ -955,10 +981,12 @@ class StableDiffusion:
             add_time_ids=None,
             bleed_ratio: float = 0.5,
             bleed_latents: torch.FloatTensor = None,
+            is_input_scaled=False,
             **kwargs,
     ):
         timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
 
 
         for timestep in tqdm(timesteps_to_run, leave=False):
             timestep = timestep.unsqueeze_(0)
             noise_pred = self.predict_noise(
@@ -967,6 +995,7 @@ class StableDiffusion:
                 timestep,
                 guidance_scale=guidance_scale,
                 add_time_ids=add_time_ids,
+                is_input_scaled=is_input_scaled,
                 **kwargs,
             )
             # some schedulers need to run separately, so do that. (euler for example)
@@ -977,6 +1006,9 @@ class StableDiffusion:
             if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
                 latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
 
+            # only skip first scaling
+            is_input_scaled = False
+
         # return latents_steps
         return latents
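When bleed_latents is set, every step except the last blends the current latents back toward a reference latent: latents = (1 - bleed_ratio) * latents + bleed_ratio * bleed_latents, a plain linear interpolation. A tiny sketch of that blend, assuming two latent tensors of the same shape (the standalone function is illustrative, not part of the repo):

# Hedged sketch: linear interpolation used for latent "bleeding" between steps.
# bleed_ratio = 0.0 keeps the current latents unchanged, 1.0 replaces them entirely.
def bleed(latents, bleed_latents, bleed_ratio=0.5):
    return latents * (1.0 - bleed_ratio) + bleed_latents * bleed_ratio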