mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Numerous fixes for time sampling. Still not perfect
This commit is contained in:
@@ -193,6 +193,9 @@ class StableDiffusion:
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
|
||||
if 'vae' in load_args and load_args['vae'] is not None:
|
||||
pipe.vae = load_args['vae']
|
||||
flush()
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
@@ -679,6 +682,18 @@ class StableDiffusion:
|
||||
text_embeddings = text_embeddings.to(self.device_torch)
|
||||
timestep = timestep.to(self.device_torch)
|
||||
|
||||
def scale_model_input(model_input, timestep_tensor):
|
||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
for idx in range(model_input.shape[0]):
|
||||
# if scheduler has step_index
|
||||
if hasattr(self.noise_scheduler, '_step_index'):
|
||||
self.noise_scheduler._step_index = None
|
||||
out_chunks.append(
|
||||
self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx])
|
||||
)
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
if self.is_xl:
|
||||
with torch.no_grad():
|
||||
# 16, 6 for bs of 4
|
||||
@@ -691,10 +706,11 @@ class StableDiffusion:
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
timestep = torch.cat([timestep] * 2)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
latent_model_input = scale_model_input(latent_model_input, timestep)
|
||||
|
||||
added_cond_kwargs = {
|
||||
# todo can we zero here the second text encoder? or match a blank string?
|
||||
@@ -784,10 +800,11 @@ class StableDiffusion:
|
||||
if do_classifier_free_guidance:
|
||||
# if we are doing classifier free guidance, need to double up
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
timestep = torch.cat([timestep] * 2)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
latent_model_input = scale_model_input(latent_model_input, timestep)
|
||||
|
||||
# check if we need to concat timesteps
|
||||
if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
|
||||
@@ -823,6 +840,23 @@ class StableDiffusion:
|
||||
|
||||
return noise_pred
|
||||
|
||||
def step_scheduler(self, model_input, latent_input, timestep_tensor):
|
||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
|
||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
for idx in range(model_input.shape[0]):
|
||||
# Reset it so it is unique for the
|
||||
if hasattr(self.noise_scheduler, '_step_index'):
|
||||
self.noise_scheduler._step_index = None
|
||||
if hasattr(self.noise_scheduler, 'is_scale_input_called'):
|
||||
self.noise_scheduler.is_scale_input_called = True
|
||||
out_chunks.append(
|
||||
self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
|
||||
0]
|
||||
)
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
|
||||
def diffuse_some_steps(
|
||||
self,
|
||||
@@ -839,6 +873,7 @@ class StableDiffusion:
|
||||
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
|
||||
|
||||
for timestep in tqdm(timesteps_to_run, leave=False):
|
||||
timestep = timestep.unsqueeze_(0)
|
||||
noise_pred = self.predict_noise(
|
||||
latents,
|
||||
text_embeddings,
|
||||
@@ -847,7 +882,9 @@ class StableDiffusion:
|
||||
add_time_ids=add_time_ids,
|
||||
**kwargs,
|
||||
)
|
||||
latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
|
||||
# some schedulers need to run separately, so do that. (euler for example)
|
||||
|
||||
latents = self.step_scheduler(noise_pred, latents, timestep)
|
||||
|
||||
# if not last step, and bleeding, bleed in some latents
|
||||
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
|
||||
|
||||
Reference in New Issue
Block a user