Sampling tests and added fixes for cleanups

This commit is contained in:
Jaret Burkett
2023-11-16 08:33:23 -07:00
parent e47006ed70
commit ad50921c41
4 changed files with 495 additions and 5 deletions

View File

@@ -675,6 +675,9 @@ class StableDiffusion:
do_classifier_free_guidance = False
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
latents = latents.to(self.device_torch)
text_embeddings = text_embeddings.to(self.device_torch)
timestep = timestep.to(self.device_torch)
if self.is_xl:
with torch.no_grad():