diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py
index 8a0ad3bf..c9a7ee34 100644
--- a/jobs/TrainJob.py
+++ b/jobs/TrainJob.py
@@ -28,6 +28,9 @@ class TrainJob(BaseJob):
         self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16
         self.logging_dir = self.get_conf('logging_dir', None)
 
+        self.writer = None
+        self.setup_tensorboard()
+
         # loads the processes from the config
         self.load_processes(process_dict)
 
@@ -38,3 +41,11 @@ class TrainJob(BaseJob):
 
         for process in self.process:
             process.run()
+
+    def setup_tensorboard(self):
+        if self.logging_dir:
+            from torch.utils.tensorboard import SummaryWriter
+            self.writer = SummaryWriter(
+                log_dir=self.logging_dir,
+                filename_suffix=f"_{self.name}"
+            )
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 3eb4e8c7..f5689143 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -45,19 +45,26 @@ class TrainVAEProcess(BaseTrainProcess):
         self.training_folder = self.get_conf('training_folder', self.job.training_folder)
         self.batch_size = self.get_conf('batch_size', 1)
         self.resolution = self.get_conf('resolution', 256)
-        self.learning_rate = self.get_conf('learning_rate', 1e-4)
+        self.learning_rate = self.get_conf('learning_rate', 1e-6)
         self.sample_every = self.get_conf('sample_every', None)
         self.epochs = self.get_conf('epochs', None)
         self.max_steps = self.get_conf('max_steps', None)
         self.save_every = self.get_conf('save_every', None)
         self.dtype = self.get_conf('dtype', 'float32')
         self.sample_sources = self.get_conf('sample_sources', None)
-        self.style_weight = self.get_conf('style_weight', 1e4)
-        self.content_weight = self.get_conf('content_weight', 1)
-        self.elbo_weight = self.get_conf('elbo_weight', 1e-8)
+        self.log_every = self.get_conf('log_every', 100)
+        self.style_weight = self.get_conf('style_weight', 0)
+        self.content_weight = self.get_conf('content_weight', 0)
+        self.kld_weight = self.get_conf('kld_weight', 0)
+        self.mse_weight = self.get_conf('mse_weight', 1e0)
+
+
+        self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
+        self.writer = self.job.writer
         self.torch_dtype = get_torch_dtype(self.dtype)
         self.save_root = os.path.join(self.training_folder, self.job.name)
         self.vgg_19 = None
+        self.progress_bar = None
 
         if self.sample_every is not None and self.sample_sources is None:
             raise ValueError('sample_every is specified but sample_sources is not')
@@ -79,6 +86,13 @@ class TrainVAEProcess(BaseTrainProcess):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
 
+    def print(self, message, **kwargs):
+        if self.progress_bar is not None:
+            self.progress_bar.write(message, **kwargs)
+            self.progress_bar.refresh()
+        else:
+            print(message, **kwargs)
+
     def load_datasets(self):
         if self.data_loader is None:
             print(f"Loading datasets")
@@ -104,17 +118,36 @@ class TrainVAEProcess(BaseTrainProcess):
             single_target=True, device=self.device)
         self.vgg_19.requires_grad_(False)
 
-    def get_mse_loss(self, pred, target):
-        loss_fn = nn.MSELoss()
-        loss = loss_fn(pred, target)
-        return loss
+    def get_style_loss(self):
+        if self.style_weight > 0:
+            return torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
+        else:
+            return torch.tensor(0.0, device=self.device)
 
-    def get_elbo_loss(self, pred, target, mu, log_var):
-        # ELBO (Evidence Lower BOund) loss, aka variational lower bound
-        reconstruction_loss = nn.MSELoss(reduction='sum')
-        BCE = reconstruction_loss(pred, target)  # reconstruction loss
-        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())  # KL divergence
-        return BCE + KLD
+    def get_content_loss(self):
+        if self.content_weight > 0:
+            return torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
+        else:
+            return torch.tensor(0.0, device=self.device)
+
+    def get_mse_loss(self, pred, target):
+        if self.mse_weight > 0:
+            loss_fn = nn.MSELoss()
+            loss = loss_fn(pred, target)
+            return loss
+        else:
+            return torch.tensor(0.0, device=self.device)
+
+    def get_kld_loss(self, mu, log_var):
+        if self.kld_weight > 0:
+            # Kullback-Leibler divergence between the encoder distribution
+            # and a standard normal. Only relevant for full VAE training
+            # (not implemented here); when training only the decoder, the
+            # latent distribution is unchanged, so this normally stays at 0.
+            KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
+            return KLD
+        else:
+            return torch.tensor(0.0, device=self.device)
 
     def save(self, step=None):
         if not os.path.exists(self.save_root):
@@ -126,7 +159,6 @@ class TrainVAEProcess(BaseTrainProcess):
             step_num = f"_{str(step).zfill(9)}"
 
         filename = f'{self.job.name}{step_num}.safetensors'
-        save_path = os.path.join(self.save_root, filename)
 
         # prepare meta
         save_meta = get_meta_for_safetensors(self.meta, self.job.name)
@@ -148,9 +180,6 @@ class TrainVAEProcess(BaseTrainProcess):
         os.makedirs(sample_folder, exist_ok=True)
 
         with torch.no_grad():
-            self.vae.encoder.eval()
-            self.vae.decoder.eval()
-
             for i, img_url in enumerate(self.sample_sources):
                 img = exif_transpose(Image.open(img_url))
                 img = img.convert('RGB')
@@ -169,13 +198,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                 decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
 
-                #convert to pillow image
+                # convert to pillow image
                 decoded = Image.fromarray((decoded * 255).astype(np.uint8))
 
-                # # decoded = decoded - 0.1
-                # decoded = decoded
-                # decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
-
                 # stack input image and decoded image
                 input_img = input_img.resize((self.resolution, self.resolution))
                 decoded = decoded.resize((self.resolution, self.resolution))
@@ -186,10 +211,10 @@ class TrainVAEProcess(BaseTrainProcess):
 
                 step_num = ''
                 if step is not None:
-                    # zeropad 9 digits
+                    # zero-pad 9 digits
                     step_num = f"_{str(step).zfill(9)}"
                 seconds_since_epoch = int(time.time())
-                # zeropad 2 digits
+                # zero-pad 2 digits
                 i_str = str(i).zfill(2)
                 filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
                 output_img.save(os.path.join(sample_folder, filename))
@@ -208,18 +233,17 @@ class TrainVAEProcess(BaseTrainProcess):
         if num_steps is None or num_steps > max_epoch_steps:
             num_steps = max_epoch_steps
 
-        print(f"Training VAE")
-        print(f" - Training folder: {self.training_folder}")
-        print(f" - Batch size: {self.batch_size}")
-        print(f" - Learning rate: {self.learning_rate}")
-        print(f" - Epochs: {num_epochs}")
-        print(f" - Max steps: {self.max_steps}")
+        self.print(f"Training VAE")
+        self.print(f" - Training folder: {self.training_folder}")
+        self.print(f" - Batch size: {self.batch_size}")
+        self.print(f" - Learning rate: {self.learning_rate}")
+        self.print(f" - Epochs: {num_epochs}")
+        self.print(f" - Max steps: {self.max_steps}")
 
         # load vae
-        print(f"Loading VAE")
-        print(f" - Loading VAE: {self.vae_path}")
+        self.print(f"Loading VAE")
+        self.print(f" - Loading VAE: {self.vae_path}")
         if self.vae is None:
-            # self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
             self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
 
         # set decoder to train
@@ -228,35 +252,36 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vae.eval()
         self.vae.decoder.train()
 
-        blocks_to_train = [
-            'mid_block',
-            'up_blocks',
-        ]
-
         params = []
 
-        # only set last 2 layers to trainable
+        # freeze all decoder params, then unfreeze only the configured blocks
         for param in self.vae.decoder.parameters():
             param.requires_grad = False
 
-        if 'mid_block' in blocks_to_train:
+        train_all = 'all' in self.blocks_to_train
+
+        # mid_block
+        if train_all or 'mid_block' in self.blocks_to_train:
             params += list(self.vae.decoder.mid_block.parameters())
             self.vae.decoder.mid_block.requires_grad_(True)
-        if 'up_blocks' in blocks_to_train:
+        # up_blocks
+        if train_all or 'up_blocks' in self.blocks_to_train:
             params += list(self.vae.decoder.up_blocks.parameters())
             self.vae.decoder.up_blocks.requires_grad_(True)
+        # conv_out (single output conv layer)
+        if train_all or 'conv_out' in self.blocks_to_train:
+            params += list(self.vae.decoder.conv_out.parameters())
+            self.vae.decoder.conv_out.requires_grad_(True)
 
-        # self.vae.decoder.train()
-
-        self.setup_vgg19()
-        self.vgg_19.requires_grad_(False)
-        self.vgg_19.eval()
-
+        if self.style_weight > 0 or self.content_weight > 0:
+            self.setup_vgg19()
+            self.vgg_19.requires_grad_(False)
+            self.vgg_19.eval()
+
+        # todo allow other optimizers
         optimizer = torch.optim.Adam(params, lr=self.learning_rate)
 
         # setup scheduler
-        # scheduler = lr_scheduler.ConstantLR
+        # todo allow other schedulers
         scheduler = torch.optim.lr_scheduler.ConstantLR(
             optimizer,
@@ -266,7 +291,7 @@ class TrainVAEProcess(BaseTrainProcess):
         )
 
         # setup tqdm progress bar
-        progress_bar = tqdm(
+        self.progress_bar = tqdm(
             total=num_steps,
             desc='Training VAE',
             leave=True
@@ -275,6 +300,16 @@ class TrainVAEProcess(BaseTrainProcess):
         step = 0
         # sample first
         self.sample()
+        blank_losses = OrderedDict({
+            "total": [],
+            "style": [],
+            "content": [],
+            "mse": [],
+            "kl": []
+        })
+        epoch_losses = copy.deepcopy(blank_losses)
+        log_losses = copy.deepcopy(blank_losses)
+
        for epoch in range(num_epochs):
             if step >= num_steps:
                 break
@@ -285,8 +320,6 @@ class TrainVAEProcess(BaseTrainProcess):
                 batch = batch.to(self.device, dtype=self.torch_dtype)
 
                 # forward pass
-                # with torch.no_grad():
-                #     batch = batch + 0.1
                 dgd = self.vae.encode(batch).latent_dist
                 mu, logvar = dgd.mean, dgd.logvar
                 latents = dgd.sample()
@@ -293,24 +326,18 @@ class TrainVAEProcess(BaseTrainProcess):
 
                 pred = self.vae.decode(latents).sample
 
-                # pred = pred + 0.1
+                # Run through VGG19
+                if self.style_weight > 0 or self.content_weight > 0:
+                    stacked = torch.cat([pred, batch], dim=0)
+                    stacked = (stacked / 2 + 0.5).clamp(0, 1)
+                    self.vgg_19(stacked)
 
-                # loss = self.get_elbo_loss(pred, batch, mu, logvar)
+                style_loss = self.get_style_loss() * self.style_weight
+                content_loss = self.get_content_loss() * self.content_weight
+                kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
+                mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
 
-                stacked = torch.cat([pred, batch], dim=0)
-                stacked = (stacked / 2 + 0.5).clamp(0, 1)
-                self.vgg_19(stacked)
-                # reduce the mean of the style_loss list
-
-                style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
-                content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
-                elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar)
-                # elbo_loss = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
-                style_loss = style_loss * self.style_weight
-                content_loss = content_loss * self.content_weight
-                elbo_loss = elbo_loss * self.elbo_weight
-
-                loss = style_loss + content_loss + elbo_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss
 
                 # Backward pass and optimization
                 optimizer.zero_grad()
@@ -322,25 +349,64 @@ class TrainVAEProcess(BaseTrainProcess):
 
                 # update progress bar
                 loss_value = loss.item()
                 # get exponent like 3.54e-4
-                loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}"
+                loss_string = f"loss: {loss_value:.2e}"
+                if self.content_weight > 0:
+                    loss_string += f" cnt: {content_loss.item():.2e}"
+                if self.style_weight > 0:
+                    loss_string += f" sty: {style_loss.item():.2e}"
+                if self.kld_weight > 0:
+                    loss_string += f" kld: {kld_loss.item():.2e}"
+                if self.mse_weight > 0:
+                    loss_string += f" mse: {mse_loss.item():.2e}"
+
                 learning_rate = optimizer.param_groups[0]['lr']
-                progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
-                progress_bar.set_description(f"E: {epoch} - S: {step} ")
-                progress_bar.update(1)
+                self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
+                self.progress_bar.set_description(f"E: {epoch}")
+                self.progress_bar.update(1)
+
+                epoch_losses["total"].append(loss_value)
+                epoch_losses["style"].append(style_loss.item())
+                epoch_losses["content"].append(content_loss.item())
+                epoch_losses["mse"].append(mse_loss.item())
+                epoch_losses["kl"].append(kld_loss.item())
+
+                log_losses["total"].append(loss_value)
+                log_losses["style"].append(style_loss.item())
+                log_losses["content"].append(content_loss.item())
+                log_losses["mse"].append(mse_loss.item())
+                log_losses["kl"].append(kld_loss.item())
 
                 if step != 0:
                     if self.sample_every and step % self.sample_every == 0:
                         # print above the progress bar
-                        print(f"Sampling at step {step}")
+                        self.print(f"Sampling at step {step}")
                         self.sample(step)
 
                     if self.save_every and step % self.save_every == 0:
                         # print above the progress bar
-                        print(f"Saving at step {step}")
+                        self.print(f"Saving at step {step}")
                         self.save(step)
+                    if self.log_every and step % self.log_every == 0:
+                        # log to tensorboard
+                        if self.writer is not None:
+                            # get avg loss
+                            for key in log_losses:
+                                log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
+                                if log_losses[key] > 0:
+                                    self.writer.add_scalar(f"loss/{key}", log_losses[key], step)
+                        # reset log losses
+                        log_losses = copy.deepcopy(blank_losses)
+
                 step += 1
 
+            # end epoch
+            if self.writer is not None:
+                # get avg loss
+                for key in epoch_losses:
+                    epoch_losses[key] = sum(epoch_losses[key]) / len(epoch_losses[key])
+                    if epoch_losses[key] > 0:
+                        self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
+                # reset epoch losses
+                epoch_losses = copy.deepcopy(blank_losses)
         self.save()
-
-        pass
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 2bd45154..e72b3098 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -4,6 +4,7 @@ from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torchvision import transforms
 from torch.utils.data import Dataset
+from tqdm import tqdm
 
 
 class ImageDataset(Dataset):
@@ -22,10 +23,18 @@ class ImageDataset(Dataset):
 
         # this might take a while
         print(f" - Preprocessing image dimensions")
-        self.file_list = [file for file in self.file_list if
-                          int(min(Image.open(file).size) * self.scale) >= self.resolution]
+        new_file_list = []
+        bad_count = 0
+        for file in tqdm(self.file_list):
+            img = Image.open(file)
+            if int(min(img.size) * self.scale) >= self.resolution:
+                new_file_list.append(file)
+            else:
+                bad_count += 1
+        self.file_list = new_file_list
 
         print(f" - Found {len(self.file_list)} images")
+        print(f" - Found {bad_count} images that are too small")
         assert len(self.file_list) > 0, f"no images found in {self.path}"
 
         self.transform = transforms.Compose([
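
Usage note: the options introduced above are plain `get_conf` keys, so they are driven entirely by the job config. The sketch below shows one way a process config exercising the new knobs might look; the key names mirror the `get_conf()` calls in this diff, but the surrounding nesting and the `train_vae` type string are assumptions about the repo's config schema, not something this diff confirms.

    # hypothetical config sketch (Python dict form); key names mirror the
    # get_conf() calls above, while the nesting and the 'train_vae' type
    # string are assumed rather than shown in this diff
    config = {
        'name': 'vae_decoder_tune',
        'logging_dir': 'output/.tensorboard',  # enables SummaryWriter in TrainJob.setup_tensorboard
        'process': [{
            'type': 'train_vae',
            'training_folder': 'output',
            'learning_rate': 1e-6,      # new, more conservative default
            'log_every': 100,           # TensorBoard scalar logging interval
            'blocks_to_train': ['up_blocks', 'conv_out'],  # or ['all']
            'mse_weight': 1.0,
            'style_weight': 0,          # VGG19 is only loaded when style/content > 0
            'content_weight': 0,
            'kld_weight': 0,
        }],
    }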
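
With `logging_dir` set, scalars are written under `loss/<key>` every `log_every` steps and under `epoch loss/<key>` at the end of each epoch, so training can be inspected with the stock TensorBoard CLI (the logdir path is whatever `logging_dir` points at):

    tensorboard --logdir output/.tensorboard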