Added initial support for Hidream E1 training

This commit is contained in:
Jaret Burkett
2025-07-27 15:12:56 -06:00
parent 3f518d9951
commit cefa2ca5fe
6 changed files with 1410 additions and 7 deletions

View File

@@ -1019,7 +1019,7 @@ class BaseModel:
image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
images = torch.stack(image_list)
images = torch.stack(image_list).to(device, dtype=dtype)
if isinstance(self.vae, AutoencoderTiny):
latents = self.vae.encode(images, return_dict=False)[0]
else: