Bug fixes in the VAE trainer. Allow targeting specific params for VAE training.

This commit is contained in:
Jaret Burkett
2025-07-26 09:20:22 -06:00
parent 3e14a674ac
commit 0d89c44624

View File

@@ -36,6 +36,7 @@ import lpips
import random
import traceback
from transformers import SiglipImageProcessor, SiglipVisionModel
import torch.nn.functional as F
IMAGE_TRANSFORMS = transforms.Compose(
[
@@ -56,6 +57,22 @@ def channel_dropout(x, p=0.5):
return x * mask
def sharpen_image(images: torch.Tensor) -> torch.Tensor:
    """Sharpen images with a fixed 3x3 kernel applied per channel.

    The kernel weights the centre pixel by 5 and its four direct
    neighbours by -1. Its weights sum to 1, so flat regions are left
    unchanged while local differences are amplified. The result is not
    clamped and can fall outside the range of the input values.

    Args:
        images: Image tensor whose channel axis is the third from the
            end, e.g. ``(N, C, H, W)``. Any channel count is accepted;
            the kernel is created with the tensor's dtype and device.

    Returns:
        A tensor with the same shape as ``images``. ``padding=1`` keeps
        the spatial size, so border pixels are computed against the
        convolution's padding values.
    """
    # Read the channel count from the input rather than assuming RGB.
    channels = images.shape[-3]

    # Define sharpening kernel
    kernel = torch.tensor(
        [
            [0, -1, 0],
            [-1, 5, -1],
            [0, -1, 0],
        ],
        dtype=images.dtype,
        device=images.device,
    ).view(1, 1, 3, 3)

    # One kernel copy per channel: (out_channels, in_channels/groups, kH, kW)
    kernel = kernel.repeat(channels, 1, 1, 1)

    # groups == channels filters each channel on its own (no channel mixing).
    return F.conv2d(images, kernel, padding=1, groups=channels)
class TrainVAEProcess(BaseTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -98,7 +115,8 @@ class TrainVAEProcess(BaseTrainProcess):
self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str) # AutoencoderKL or AutoencoderTiny
self.only_if_contains = self.get_conf('only_if_contains', None)
self.do_pooled_exits = False
self.VaeClass = AutoencoderKL
if self.vae_type == 'AutoencoderTiny':
@@ -253,6 +271,9 @@ class TrainVAEProcess(BaseTrainProcess):
def get_clip_embeddings(self, image_n1p1):
tensors_0_1 = (image_n1p1 + 1) / 2
# sharpen images
tensors_0_1 = sharpen_image(tensors_0_1)
tensors_0_1 = tensors_0_1.clamp(0, 1)
# resize if needed
@@ -494,6 +515,9 @@ class TrainVAEProcess(BaseTrainProcess):
target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
latent = self.vae.encode(img, return_dict=False)[0]
if hasattr(latent, 'sample'):
latent = latent.sample()
shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
latent = self.vae.config['scaling_factor'] * (latent - shift)
@@ -649,24 +673,24 @@ class TrainVAEProcess(BaseTrainProcess):
train_all = 'all' in self.blocks_to_train
if train_all:
params = list(self.vae.decoder.parameters())
params = list(self.vae.decoder.named_parameters())
self.vae.decoder.requires_grad_(True)
if self.train_encoder:
# encoder
params += list(self.vae.encoder.parameters())
params += list(self.vae.encoder.named_parameters())
self.vae.encoder.requires_grad_(True)
else:
# mid_block
if train_all or 'mid_block' in self.blocks_to_train:
params += list(self.vae.decoder.mid_block.parameters())
params += list(self.vae.decoder.mid_block.named_parameters())
self.vae.decoder.mid_block.requires_grad_(True)
# up_blocks
if train_all or 'up_blocks' in self.blocks_to_train:
params += list(self.vae.decoder.up_blocks.parameters())
params += list(self.vae.decoder.up_blocks.named_parameters())
self.vae.decoder.up_blocks.requires_grad_(True)
# conv_out (single conv layer output)
if train_all or 'conv_out' in self.blocks_to_train:
params += list(self.vae.decoder.conv_out.parameters())
params += list(self.vae.decoder.conv_out.named_parameters())
self.vae.decoder.conv_out.requires_grad_(True)
if self.style_weight > 0 or self.content_weight > 0:
@@ -683,6 +707,15 @@ class TrainVAEProcess(BaseTrainProcess):
if self.lpips_weight > 0 and self.lpips_loss is None:
# self.lpips_loss = lpips.LPIPS(net='vgg')
self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
if self.only_if_contains is not None:
orig_params = params
params = []
for name, param in orig_params:
for contains in self.only_if_contains:
if contains in name:
params.append(param)
break
optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
optimizer_params=self.optimizer_params)