Scale target vae latent before targeting it
@@ -411,6 +411,17 @@ class TrainVAEProcess(BaseTrainProcess):
         img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
         img = img
         # latent = self.vae.encode(img).latent_dist.sample()
+
+        target_latent = None
+        if self.target_latent_vae is not None:
+            target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
+            target_input_size = (int(img.shape[2] * target_input_scale), int(img.shape[3] * target_input_scale))
+            # resize to target input size
+            target_input_batch = Resize(target_input_size)(img).to(self.device, dtype=torch.float32)
+            target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+            shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0
+            target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift)
+            target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
         latent = self.vae.encode(img, return_dict=False)[0]

         latent_img = latent.clone()
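The added shift-and-scale step follows the latent normalization convention of diffusers-style autoencoders: subtract the config's shift_factor (when one is set) before multiplying by scaling_factor. A minimal standalone sketch of that convention, assuming a VAE whose config supports the same dict-style access the hunk uses:

import torch

def normalize_latent(vae, latent: torch.Tensor) -> torch.Tensor:
    # shift_factor may be None (SD1.x-style configs), while newer VAEs
    # such as SD3/Flux set a nonzero value; fall back to a zero shift.
    shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
    return vae.config['scaling_factor'] * (latent - shift)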
@@ -443,6 +454,9 @@ class TrainVAEProcess(BaseTrainProcess):
         # convert to pillow image
         latent_img = Image.fromarray(latent_img)

+        if target_latent is not None:
+            latent = target_latent.to(latent.device, dtype=latent.dtype)
+
         decoded = self.vae.decode(latent).sample
         decoded = (decoded / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
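For reference, the decode path below the swap maps the VAE's roughly [-1, 1] output back to a viewable image. A minimal sketch of that denormalization, assuming a diffusers-style decode() whose result exposes a .sample tensor of shape (B, C, H, W):

import numpy as np
import torch
from PIL import Image

@torch.no_grad()
def decode_to_pil(vae, latent: torch.Tensor) -> Image.Image:
    decoded = vae.decode(latent).sample          # roughly in [-1, 1]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)    # map to [0, 1]
    array = decoded[0].permute(1, 2, 0).float().cpu().numpy()  # CHW -> HWC
    return Image.fromarray((array * 255).round().astype(np.uint8))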
@@ -589,7 +603,7 @@ class TrainVAEProcess(BaseTrainProcess):

         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
-            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)
+            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)

         optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                   optimizer_params=self.optimizer_params)
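This change pins the LPIPS network to bfloat16 instead of the process dtype, roughly halving its memory footprint relative to float32. A minimal sketch of the same lazy-initialization pattern, assuming the lpips package; the PerceptualLoss wrapper name is illustrative, not from the repo:

import lpips
import torch

class PerceptualLoss:
    def __init__(self, device: torch.device):
        self.device = device
        self._lpips = None  # built on first use, as in the hunk above

    def model(self) -> lpips.LPIPS:
        if self._lpips is None:
            # VGG backbone matches the hunk; bfloat16 keeps weights and
            # activations at half the float32 size.
            self._lpips = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
        return self._lpips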
@@ -670,6 +684,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 # resize to target input size
                 target_input_batch = Resize(target_input_size)(batch).to(self.device, dtype=torch.float32)
                 target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                # shift scale it
+                shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0
+                target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift)
                 target_latent = target_latent.to(self.device, dtype=self.torch_dtype)


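The resize exists so both encoders produce latents of the same spatial size: each VAE downsamples pixels by its own factor, so the target VAE's input must be rescaled by target_vae_scale_factor / vae_scale_factor. A short worked example with assumed factors of 8 (training VAE) and 16 (target VAE):

import torch
from torchvision.transforms import Resize

vae_scale_factor = 8          # training VAE: 256px -> 32x32 latent
target_vae_scale_factor = 16  # target VAE:  512px -> 32x32 latent

batch = torch.randn(1, 3, 256, 256)
target_input_scale = target_vae_scale_factor / vae_scale_factor  # 2.0
target_input_size = (int(batch.shape[2] * target_input_scale),
                     int(batch.shape[3] * target_input_scale))   # (512, 512)
target_input_batch = Resize(target_input_size)(batch)
# Both encoders now yield 32x32 latents, so the target lines up spatially.
print(target_input_batch.shape)  # torch.Size([1, 3, 512, 512])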
@@ -781,9 +798,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
                 if self.lpips_weight > 0:
                     lpips_loss = self.lpips_loss(
-                        pred.clamp(-1, 1),
-                        batch.clamp(-1, 1)
-                    ).mean() * self.lpips_weight
+                        pred.clamp(-1, 1).to(self.device, dtype=torch.bfloat16),
+                        batch.clamp(-1, 1).to(self.device, dtype=torch.bfloat16)
+                    ).float().mean() * self.lpips_weight
                 else:
                     lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
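Here the clamped inputs are cast to bfloat16 to match the dtype the LPIPS network now lives in, and the added .float() promotes the per-sample distances back to float32 before the mean, so the reduction and the weighted contribution to the total loss stay in full precision. A minimal usage sketch under the same lpips assumption as above, with stand-in tensors:

import lpips
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lpips_loss_fn = lpips.LPIPS(net='vgg').to(device, dtype=torch.bfloat16)

pred = torch.rand(4, 3, 256, 256, device=device) * 2 - 1   # stand-in prediction
batch = torch.rand(4, 3, 256, 256, device=device) * 2 - 1  # stand-in target

# lpips expects inputs in [-1, 1]; reduce in float32 after the bf16 forward.
lpips_loss = lpips_loss_fn(
    pred.clamp(-1, 1).to(device, dtype=torch.bfloat16),
    batch.clamp(-1, 1).to(device, dtype=torch.bfloat16),
).float().mean()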