mirror of https://github.com/ostris/ai-toolkit.git
Bug fixes on the VAE trainer. Allow targeting specific params for VAE training.
@@ -36,6 +36,7 @@ import lpips
 import random
 import traceback
 from transformers import SiglipImageProcessor, SiglipVisionModel
+import torch.nn.functional as F
 
 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -56,6 +57,22 @@ def channel_dropout(x, p=0.5):
     return x * mask
 
 
+def sharpen_image(images: torch.Tensor) -> torch.Tensor:
+    # Define sharpening kernel
+    kernel = torch.tensor([
+        [ 0, -1, 0],
+        [-1, 5, -1],
+        [ 0, -1, 0]
+    ], dtype=images.dtype, device=images.device).view(1, 1, 3, 3)
+
+    # Repeat kernel for each channel
+    kernel = kernel.repeat(3, 1, 1, 1)  # (out_channels, in_channels/groups, kH, kW)
+
+    # Apply the filter
+    sharpened = F.conv2d(images, kernel, padding=1, groups=3)
+
+    return sharpened
+
 class TrainVAEProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
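For context, sharpen_image applies a standard 3x3 sharpening kernel to each channel via a grouped convolution. The kernel weights sum to 1 (5 - 4), so flat regions pass through unchanged away from the zero-padded border. A minimal sanity check, with illustrative tensors not taken from the repo:

import torch

images = torch.rand(4, 3, 64, 64)      # batch of 3-channel images in [0, 1]
flat = torch.full((1, 3, 8, 8), 0.5)   # constant image

out = sharpen_image(images)
assert out.shape == images.shape       # grouped conv with padding=1 preserves shape

# Interior pixels of a flat image are unchanged (kernel weights sum to 1);
# the 1-pixel border differs because conv2d pads with zeros.
inner = sharpen_image(flat)[..., 1:-1, 1:-1]
assert torch.allclose(inner, flat[..., 1:-1, 1:-1], atol=1e-6)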
@@ -98,7 +115,8 @@ class TrainVAEProcess(BaseTrainProcess):
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
         self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str)  # AutoencoderKL or AutoencoderTiny
+        self.only_if_contains = self.get_conf('only_if_contains', None)
 
         self.do_pooled_exits = False
         self.VaeClass = AutoencoderKL
         if self.vae_type == 'AutoencoderTiny':
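only_if_contains is read straight from the process config with no default, so it stays None unless set. A hypothetical config fragment in dict form, keyed to match the get_conf calls above (the values are illustrative):

from collections import OrderedDict

config = OrderedDict({
    'train_encoder': False,
    'random_scaling': False,
    'vae_type': 'AutoencoderKL',                      # or 'AutoencoderTiny'
    'only_if_contains': ['up_blocks.2', 'conv_out'],  # hypothetical name substrings
})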
@@ -253,6 +271,9 @@ class TrainVAEProcess(BaseTrainProcess):
 
     def get_clip_embeddings(self, image_n1p1):
         tensors_0_1 = (image_n1p1 + 1) / 2
+        # sharpen images
+        tensors_0_1 = sharpen_image(tensors_0_1)
+
         tensors_0_1 = tensors_0_1.clamp(0, 1)
 
         # resize if needed
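Restated standalone, the embedding path now maps [-1, 1] inputs to [0, 1], sharpens, then clamps, since sharpening can overshoot the valid range (batch shape illustrative):

import torch

image_n1p1 = torch.rand(2, 3, 224, 224) * 2 - 1  # inputs in [-1, 1]
tensors_0_1 = (image_n1p1 + 1) / 2               # -> [0, 1]
tensors_0_1 = sharpen_image(tensors_0_1)         # may leave [0, 1]
tensors_0_1 = tensors_0_1.clamp(0, 1)            # pull overshoot back in range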
@@ -494,6 +515,9 @@ class TrainVAEProcess(BaseTrainProcess):
         target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
         latent = self.vae.encode(img, return_dict=False)[0]
+
+        if hasattr(latent, 'sample'):
+            latent = latent.sample()
 
         shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
         latent = self.vae.config['scaling_factor'] * (latent - shift)
 
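The shift/scale step is the usual diffusers latent normalization, latent = scaling_factor * (latent - shift_factor), with the shift treated as 0 when the config leaves it unset (SD 1.x-style VAEs define no shift; Flux/SD3-style VAEs do). A sketch of the forward map and the inverse a decode-side caller would apply; vae and latent here are illustrative:

shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
scale = vae.config['scaling_factor']

latent_norm = scale * (latent - shift)    # normalized latent the loss sees
latent_raw = latent_norm / scale + shift  # inverse, applied before vae.decode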
@@ -649,24 +673,24 @@ class TrainVAEProcess(BaseTrainProcess):
         train_all = 'all' in self.blocks_to_train
 
         if train_all:
-            params = list(self.vae.decoder.parameters())
+            params = list(self.vae.decoder.named_parameters())
             self.vae.decoder.requires_grad_(True)
             if self.train_encoder:
                 # encoder
-                params += list(self.vae.encoder.parameters())
+                params += list(self.vae.encoder.named_parameters())
                 self.vae.encoder.requires_grad_(True)
         else:
             # mid_block
             if train_all or 'mid_block' in self.blocks_to_train:
-                params += list(self.vae.decoder.mid_block.parameters())
+                params += list(self.vae.decoder.mid_block.named_parameters())
                 self.vae.decoder.mid_block.requires_grad_(True)
             # up_blocks
             if train_all or 'up_blocks' in self.blocks_to_train:
-                params += list(self.vae.decoder.up_blocks.parameters())
+                params += list(self.vae.decoder.up_blocks.named_parameters())
                 self.vae.decoder.up_blocks.requires_grad_(True)
             # conv_out (single conv layer output)
             if train_all or 'conv_out' in self.blocks_to_train:
-                params += list(self.vae.decoder.conv_out.parameters())
+                params += list(self.vae.decoder.conv_out.named_parameters())
                 self.vae.decoder.conv_out.requires_grad_(True)
 
         if self.style_weight > 0 or self.content_weight > 0:
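Switching parameters() to named_parameters() changes each collected entry from a bare tensor to a (name, tensor) pair, which is what the filter in the next hunk keys on. Note that names are relative to whichever module the call is made on:

# Illustrative; `vae` stands in for the loaded autoencoder.
for name, param in vae.decoder.conv_out.named_parameters():
    # e.g. 'weight' and 'bias' here, but 'up_blocks.0.resnets.0.conv1.weight'
    # when collected from the whole decoder instead.
    print(name, tuple(param.shape))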
@@ -683,6 +707,15 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
             self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
 
+        if self.only_if_contains is not None:
+            orig_params = params
+            params = []
+            for name, param in orig_params:
+                for contains in self.only_if_contains:
+                    if contains in name:
+                        params.append(param)
+                        break
+
         optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                   optimizer_params=self.optimizer_params)
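The new block reduces the (name, tensor) pairs back to plain tensors, keeping a parameter when its name contains any configured substring. An equivalent standalone sketch with hypothetical substrings:

only_if_contains = ['up_blocks.2', 'conv_out']  # hypothetical targets

named_params = list(vae.decoder.named_parameters())
params = [p for name, p in named_params
          if any(c in name for c in only_if_contains)]

One caveat visible in the hunk: when only_if_contains is left unset, params still holds (name, tensor) pairs by the time it reaches get_optimizer; whether that helper accepts pairs is not shown here.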