mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
More timestep fixes
This commit is contained in:
@@ -331,7 +331,7 @@ class SliderConfig:
|
||||
print(f"Building slider targets")
|
||||
for target in targets:
|
||||
if target.shuffle:
|
||||
target_permutations = get_slider_target_permutations(target, max_permutations=100)
|
||||
target_permutations = get_slider_target_permutations(target, max_permutations=8)
|
||||
self.targets = self.targets + target_permutations
|
||||
else:
|
||||
self.targets.append(target)
|
||||
|
||||
@@ -682,15 +682,30 @@ class StableDiffusion:
|
||||
text_embeddings = text_embeddings.to(self.device_torch)
|
||||
timestep = timestep.to(self.device_torch)
|
||||
|
||||
# if timestep is zero dim, unsqueeze it
|
||||
if len(timestep.shape) == 0:
|
||||
timestep = timestep.unsqueeze(0)
|
||||
|
||||
# if we only have 1 timestep, we can just use the same timestep for all
|
||||
if timestep.shape[0] == 1 and latents.shape[0] > 1:
|
||||
# check if it is rank 1 or 2
|
||||
if len(timestep.shape) == 1:
|
||||
timestep = timestep.repeat(latents.shape[0])
|
||||
else:
|
||||
timestep = timestep.repeat(latents.shape[0], 0)
|
||||
|
||||
|
||||
def scale_model_input(model_input, timestep_tensor):
|
||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
# unsqueeze if timestep is zero dim
|
||||
for idx in range(model_input.shape[0]):
|
||||
# if scheduler has step_index
|
||||
if hasattr(self.noise_scheduler, '_step_index'):
|
||||
self.noise_scheduler._step_index = None
|
||||
out_chunks.append(
|
||||
self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx])
|
||||
self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx])
|
||||
)
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
@@ -845,6 +860,10 @@ class StableDiffusion:
|
||||
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
|
||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
if len(timestep_chunks) == 1 and len(mi_chunks) > 1:
|
||||
# expand timestep to match
|
||||
timestep_chunks = timestep_chunks * len(mi_chunks)
|
||||
|
||||
for idx in range(model_input.shape[0]):
|
||||
# Reset it so it is unique for the
|
||||
if hasattr(self.noise_scheduler, '_step_index'):
|
||||
|
||||
Reference in New Issue
Block a user