diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index b7aedd35..8fdb2670 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -331,7 +331,7 @@ class SliderConfig: print(f"Building slider targets") for target in targets: if target.shuffle: - target_permutations = get_slider_target_permutations(target, max_permutations=100) + target_permutations = get_slider_target_permutations(target, max_permutations=8) self.targets = self.targets + target_permutations else: self.targets.append(target) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 9ee4d995..39114ee6 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -682,15 +682,30 @@ class StableDiffusion: text_embeddings = text_embeddings.to(self.device_torch) timestep = timestep.to(self.device_torch) + # if timestep is zero dim, unsqueeze it + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + # if we only have 1 timestep, we can just use the same timestep for all + if timestep.shape[0] == 1 and latents.shape[0] > 1: + # rank 1: tile along dim 0; rank 2: tile dim 0 and keep dim 1 as is (factor 1, a factor of 0 would empty it) + if len(timestep.shape) == 1: + timestep = timestep.repeat(latents.shape[0]) + else: + timestep = timestep.repeat(latents.shape[0], 1) + + def scale_model_input(model_input, timestep_tensor): mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) out_chunks = [] + # scale each sample chunk with its own timestep chunk for idx in range(model_input.shape[0]): # if scheduler has step_index if hasattr(self.noise_scheduler, '_step_index'): self.noise_scheduler._step_index = None out_chunks.append( - self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx]) + self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx]) ) return torch.cat(out_chunks, dim=0) @@ -845,6 +860,10 @@ class StableDiffusion: latent_chunks = torch.chunk(latent_input, 
latent_input.shape[0], dim=0) timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) out_chunks = [] + if len(timestep_chunks) == 1 and len(mi_chunks) > 1: + # expand timestep to match + timestep_chunks = timestep_chunks * len(mi_chunks) + for idx in range(model_input.shape[0]): # Reset it so it is unique for the if hasattr(self.noise_scheduler, '_step_index'):