mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Various features and fixes. Too much brain fog to do a proper description
This commit is contained in:
@@ -356,7 +356,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# we have to encode images into latents for now
|
||||
# we also denoise as the unaugmented tensor is not a noisy diffirental
|
||||
with torch.no_grad():
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
|
||||
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
|
||||
target = unaugmented_latents.detach()
|
||||
|
||||
@@ -907,6 +907,17 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.timer.start('preprocess_batch')
|
||||
batch = self.preprocess_batch(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# sanity check
|
||||
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
|
||||
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for encoder in self.sd.text_encoder:
|
||||
if encoder.dtype != self.sd.te_torch_dtype:
|
||||
encoder.to(self.sd.te_torch_dtype)
|
||||
else:
|
||||
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
|
||||
self.sd.text_encoder.to(self.sd.te_torch_dtype)
|
||||
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
if self.train_config.do_cfg or self.train_config.do_random_cfg:
|
||||
# pick random negative prompts
|
||||
|
||||
Reference in New Issue
Block a user