More timestep fixes

This commit is contained in:
Jaret Burkett
2023-11-28 10:34:08 -07:00
parent 792a5e37e2
commit c2a4b8e058
2 changed files with 21 additions and 2 deletions

View File

@@ -331,7 +331,7 @@ class SliderConfig:
print(f"Building slider targets")
for target in targets:
if target.shuffle:
target_permutations = get_slider_target_permutations(target, max_permutations=100)
target_permutations = get_slider_target_permutations(target, max_permutations=8)
self.targets = self.targets + target_permutations
else:
self.targets.append(target)

View File

@@ -682,15 +682,30 @@ class StableDiffusion:
text_embeddings = text_embeddings.to(self.device_torch)
timestep = timestep.to(self.device_torch)
# if timestep is zero dim, unsqueeze it
if len(timestep.shape) == 0:
timestep = timestep.unsqueeze(0)
# if we only have 1 timestep, we can just use the same timestep for all
if timestep.shape[0] == 1 and latents.shape[0] > 1:
# check if it is rank 1 or 2
if len(timestep.shape) == 1:
timestep = timestep.repeat(latents.shape[0])
else:
timestep = timestep.repeat(latents.shape[0], 0)
def scale_model_input(model_input, timestep_tensor):
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
# unsqueeze if timestep is zero dim
for idx in range(model_input.shape[0]):
# if scheduler has step_index
if hasattr(self.noise_scheduler, '_step_index'):
self.noise_scheduler._step_index = None
out_chunks.append(
self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx])
self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx])
)
return torch.cat(out_chunks, dim=0)
@@ -845,6 +860,10 @@ class StableDiffusion:
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
if len(timestep_chunks) == 1 and len(mi_chunks) > 1:
# expand timestep to match
timestep_chunks = timestep_chunks * len(mi_chunks)
for idx in range(model_input.shape[0]):
# Reset it so it is unique for the
if hasattr(self.noise_scheduler, '_step_index'):