mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-22 23:39:21 +00:00
Tied in ant tested TI script
This commit is contained in:
@@ -435,7 +435,7 @@ class StableDiffusion:
|
||||
text_embeddings = train_tools.concat_prompt_embeddings(
|
||||
unconditional_embeddings, # negative embedding
|
||||
conditional_embeddings, # positive embedding
|
||||
latents.shape[0], # batch size
|
||||
1, # batch size
|
||||
)
|
||||
elif text_embeddings is None and conditional_embeddings is not None:
|
||||
# not doing cfg
|
||||
@@ -506,6 +506,17 @@ class StableDiffusion:
|
||||
|
||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
# check if we need to concat timesteps
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
ts_bs = timestep.shape[0]
|
||||
if ts_bs != latent_model_input.shape[0]:
|
||||
if ts_bs == 1:
|
||||
timestep = torch.cat([timestep] * latent_model_input.shape[0])
|
||||
elif ts_bs * 2 == latent_model_input.shape[0]:
|
||||
timestep = torch.cat([timestep] * 2)
|
||||
else:
|
||||
raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
|
||||
Reference in New Issue
Block a user