diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 27daefab..04be4a1e 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -36,6 +36,7 @@ import lpips
 import random
 import traceback
 from transformers import SiglipImageProcessor, SiglipVisionModel
+import torch.nn.functional as F
 
 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -56,6 +57,22 @@ def channel_dropout(x, p=0.5):
     return x * mask
 
 
+def sharpen_image(images: torch.Tensor) -> torch.Tensor:
+    # 3x3 sharpening kernel; weights sum to 1, so flat regions pass through unchanged
+    kernel = torch.tensor([
+        [ 0, -1,  0],
+        [-1,  5, -1],
+        [ 0, -1,  0]
+    ], dtype=images.dtype, device=images.device).view(1, 1, 3, 3)
+
+    # Repeat the kernel for each of the 3 channels: (out_channels, in_channels/groups, kH, kW)
+    kernel = kernel.repeat(3, 1, 1, 1)
+
+    # Apply the filter depthwise (groups=3) so channels do not mix
+    sharpened = F.conv2d(images, kernel, padding=1, groups=3)
+
+    return sharpened
+
 class TrainVAEProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
@@ -98,7 +115,8 @@ class TrainVAEProcess(BaseTrainProcess):
 
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
         self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str)  # AutoencoderKL or AutoencoderTiny
-
+        self.only_if_contains = self.get_conf('only_if_contains', None)
+        self.do_pooled_exits = False
         self.VaeClass = AutoencoderKL
         if self.vae_type == 'AutoencoderTiny':
@@ -253,6 +271,9 @@ class TrainVAEProcess(BaseTrainProcess):
 
     def get_clip_embeddings(self, image_n1p1):
         tensors_0_1 = (image_n1p1 + 1) / 2
+        # sharpen images before computing the CLIP embeddings
+        tensors_0_1 = sharpen_image(tensors_0_1)
+
         tensors_0_1 = tensors_0_1.clamp(0, 1)
 
         # resize if needed
@@ -494,6 +515,9 @@ class TrainVAEProcess(BaseTrainProcess):
                     target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
 
                 latent = self.vae.encode(img, return_dict=False)[0]
+                if hasattr(latent, 'sample'):
+                    latent = latent.sample()
+
                 shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
                 latent = self.vae.config['scaling_factor'] * (latent - shift)
 
@@ -649,24 +673,24 @@ class TrainVAEProcess(BaseTrainProcess):
         train_all = 'all' in self.blocks_to_train
 
         if train_all:
-            params = list(self.vae.decoder.parameters())
+            params = list(self.vae.decoder.named_parameters())
             self.vae.decoder.requires_grad_(True)
             if self.train_encoder:
                 # encoder
-                params += list(self.vae.encoder.parameters())
+                params += list(self.vae.encoder.named_parameters())
                 self.vae.encoder.requires_grad_(True)
         else:
             # mid_block
             if train_all or 'mid_block' in self.blocks_to_train:
-                params += list(self.vae.decoder.mid_block.parameters())
+                params += list(self.vae.decoder.mid_block.named_parameters())
                 self.vae.decoder.mid_block.requires_grad_(True)
             # up_blocks
             if train_all or 'up_blocks' in self.blocks_to_train:
-                params += list(self.vae.decoder.up_blocks.parameters())
+                params += list(self.vae.decoder.up_blocks.named_parameters())
                 self.vae.decoder.up_blocks.requires_grad_(True)
             # conv_out (single conv layer output)
             if train_all or 'conv_out' in self.blocks_to_train:
-                params += list(self.vae.decoder.conv_out.parameters())
+                params += list(self.vae.decoder.conv_out.named_parameters())
                 self.vae.decoder.conv_out.requires_grad_(True)
 
         if self.style_weight > 0 or self.content_weight > 0:
@@ -683,6 +707,19 @@ class TrainVAEProcess(BaseTrainProcess):
 
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
             self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
+
+        if self.only_if_contains is not None:
+            # keep only the params whose names contain one of the given substrings
+            orig_params = params
+            params = []
+            for name, param in orig_params:
+                for contains in self.only_if_contains:
+                    if contains in name:
+                        params.append(param)
+                        break
+        else:
+            # named_parameters() yields (name, param) pairs; unwrap them for the optimizer
+            params = [param for _, param in params]
 
         optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, optimizer_params=self.optimizer_params)
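
Note on sharpen_image: a minimal standalone sanity check (not part of the patch; all names below are local to the snippet). The kernel's weights sum to 1, so flat regions are preserved while edges are amplified. Because F.conv2d zero-pads, the one-pixel border is brightened; the clamp(0, 1) that follows in get_clip_embeddings bounds that artifact.

import torch
import torch.nn.functional as F

kernel = torch.tensor([
    [ 0., -1.,  0.],
    [-1.,  5., -1.],
    [ 0., -1.,  0.],
]).view(1, 1, 3, 3).repeat(3, 1, 1, 1)

flat = torch.full((1, 3, 8, 8), 0.5)
out = F.conv2d(flat, kernel, padding=1, groups=3)
# interior pixels are unchanged: 5*0.5 - 4*0.5 = 0.5 ...
assert torch.allclose(out[..., 1:-1, 1:-1], flat[..., 1:-1, 1:-1])
# ... but zero padding brightens the border: 5*0.5 - 3*0.5 = 1.0
assert out[0, 0, 0, 1].item() == 1.0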
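Note on the hasattr(latent, 'sample') guard: AutoencoderKL.encode(..., return_dict=False)[0] yields a latent distribution that must be sampled, while AutoencoderTiny already returns a plain latent tensor, which has no .sample attribute. A sketch of the two cases, assuming diffusers exposes DiagonalGaussianDistribution at the import path below (it has moved across diffusers versions):

import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution  # path varies by version

dist = DiagonalGaussianDistribution(torch.randn(1, 8, 4, 4))  # mean+logvar, as AutoencoderKL.encode() yields
plain = torch.randn(1, 4, 4, 4)                               # plain latents, as AutoencoderTiny yields

for latent in (dist, plain):
    if hasattr(latent, 'sample'):
        latent = latent.sample()   # only the KL distribution needs sampling
    assert latent.shape == (1, 4, 4, 4)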
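Note on only_if_contains: this is why the patch switches .parameters() to .named_parameters(); the filter needs parameter names to substring-match against. A minimal sketch (the module and names below are made up for illustration). Since the config value is iterated, it should be a list of substrings; a bare string would be matched one character at a time.

import torch.nn as nn

model = nn.Sequential()
model.add_module('up_blocks', nn.Linear(4, 4))
model.add_module('conv_out', nn.Linear(4, 4))

only_if_contains = ['conv_out']
named = list(model.named_parameters())  # [('up_blocks.weight', ...), ('up_blocks.bias', ...), ...]
kept = [p for name, p in named if any(c in name for c in only_if_contains)]
assert len(kept) == 2                   # conv_out.weight and conv_out.bias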