diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 2c2ffd59..018c3638 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -408,7 +408,7 @@ class LatentCachingMixin: if to_memory: # keep it in memory - file_item._encoded_latent = latent.to('cpu', dtype=self.sd.dtype) + file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) flush(garbage_collect=False) file_item.is_latent_cached = True