Tied in ant tested TI script

This commit is contained in:
Jaret Burkett
2023-08-23 13:26:28 -06:00
parent 2e6c55c720
commit d298240cec
6 changed files with 89 additions and 58 deletions

View File

@@ -435,7 +435,7 @@ class StableDiffusion:
text_embeddings = train_tools.concat_prompt_embeddings(
unconditional_embeddings, # negative embedding
conditional_embeddings, # positive embedding
latents.shape[0], # batch size
1, # batch size
)
elif text_embeddings is None and conditional_embeddings is not None:
# not doing cfg
@@ -506,6 +506,17 @@ class StableDiffusion:
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
# check if we need to concat timesteps
if isinstance(timestep, torch.Tensor):
ts_bs = timestep.shape[0]
if ts_bs != latent_model_input.shape[0]:
if ts_bs == 1:
timestep = torch.cat([timestep] * latent_model_input.shape[0])
elif ts_bs * 2 == latent_model_input.shape[0]:
timestep = torch.cat([timestep] * 2)
else:
raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
# predict the noise residual
noise_pred = self.unet(
latent_model_input,