Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
More timestep fixes
@@ -331,7 +331,7 @@ class SliderConfig:
         print(f"Building slider targets")
         for target in targets:
             if target.shuffle:
-                target_permutations = get_slider_target_permutations(target, max_permutations=100)
+                target_permutations = get_slider_target_permutations(target, max_permutations=8)
                 self.targets = self.targets + target_permutations
             else:
                 self.targets.append(target)
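The only functional change above is lowering the permutation cap passed to get_slider_target_permutations from 100 to 8, which keeps shuffled slider targets from growing combinatorially. As a rough, hypothetical sketch of what such a cap does (the real helper in ai-toolkit may be implemented differently), a permutation builder can simply stop enumerating once the limit is reached:

import itertools

def capped_permutations(items, max_permutations=8):
    # Illustrative only: take at most `max_permutations` orderings of `items`.
    # 4 items have 24 permutations; with the cap only the first 8 are kept.
    return [list(p) for p in itertools.islice(itertools.permutations(items), max_permutations)]

print(len(capped_permutations(["a", "b", "c", "d"])))  # -> 8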
@@ -682,15 +682,30 @@ class StableDiffusion:
         text_embeddings = text_embeddings.to(self.device_torch)
         timestep = timestep.to(self.device_torch)
 
+        # if timestep is zero dim, unsqueeze it
+        if len(timestep.shape) == 0:
+            timestep = timestep.unsqueeze(0)
+
+        # if we only have 1 timestep, we can just use the same timestep for all
+        if timestep.shape[0] == 1 and latents.shape[0] > 1:
+            # check if it is rank 1 or 2
+            if len(timestep.shape) == 1:
+                timestep = timestep.repeat(latents.shape[0])
+            else:
+                timestep = timestep.repeat(latents.shape[0], 0)
+
+
         def scale_model_input(model_input, timestep_tensor):
             mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+            timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
             out_chunks = []
+            # unsqueeze if timestep is zero dim
             for idx in range(model_input.shape[0]):
                 # if scheduler has step_index
                 if hasattr(self.noise_scheduler, '_step_index'):
                     self.noise_scheduler._step_index = None
                 out_chunks.append(
-                    self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx])
+                    self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx])
                 )
             return torch.cat(out_chunks, dim=0)
 
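The block added above normalizes timestep before the per-sample loop: a zero-dim scalar is unsqueezed to rank 1, and a single timestep is repeated so there is one entry per latent, which lets scale_model_input chunk it alongside the model input. A minimal standalone sketch of that normalization (the function name and toy shapes below are illustrative, not from the repo):

import torch

def normalize_timestep(timestep: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Zero-dim scalar -> rank 1 so it can be chunked along dim 0.
    if len(timestep.shape) == 0:
        timestep = timestep.unsqueeze(0)
    # A single timestep is reused for every latent in the batch.
    if timestep.shape[0] == 1 and batch_size > 1:
        timestep = timestep.repeat(batch_size)
    return timestep

latents = torch.randn(4, 4, 64, 64)          # toy batch of 4 latents
t = normalize_timestep(torch.tensor(999), latents.shape[0])
print(t.shape)                               # torch.Size([4])

# With matching lengths, torch.chunk yields one timestep per latent,
# which is what the per-chunk scale_model_input call relies on.
timestep_chunks = torch.chunk(t, t.shape[0], dim=0)
print(len(timestep_chunks))                  # 4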
@@ -845,6 +860,10 @@ class StableDiffusion:
             latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
             timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
             out_chunks = []
+            if len(timestep_chunks) == 1 and len(mi_chunks) > 1:
+                # expand timestep to match
+                timestep_chunks = timestep_chunks * len(mi_chunks)
+
             for idx in range(model_input.shape[0]):
                 # Reset it so it is unique for the
                 if hasattr(self.noise_scheduler, '_step_index'):
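The four lines added in this hunk cover the case where chunking produces a single timestep chunk but several latent chunks: multiplying the chunk sequence by the batch length reuses the same one-element tensor for every sample. A quick toy sketch of that pattern (variable names approximate the diff; the tensors are made up):

import torch

latent_input = torch.randn(3, 4, 64, 64)     # toy batch of 3 latents
timestep_tensor = torch.tensor([42])         # one shared timestep

latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)

# One timestep chunk, three latent chunks: repeating the chunk tuple gives
# one timestep chunk per latent chunk without copying any tensor data.
if len(timestep_chunks) == 1 and len(latent_chunks) > 1:
    timestep_chunks = timestep_chunks * len(latent_chunks)

print(len(latent_chunks), len(timestep_chunks))  # 3 3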