mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Style and content loss working
This commit is contained in:
@@ -15,7 +15,9 @@ from jobs.process import BaseTrainProcess
|
||||
from toolkit.kohya_model_util import load_vae
|
||||
from toolkit.data_loader import ImageDataset
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.style import get_style_model_and_losses
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from diffusers import AutoencoderKL
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import numpy as np
|
||||
@@ -27,15 +29,9 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
]
|
||||
)
|
||||
|
||||
INVERSE_IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
transforms.Normalize(
|
||||
mean=[-0.5/0.5],
|
||||
std=[1/0.5]
|
||||
),
|
||||
transforms.ToPILImage(),
|
||||
]
|
||||
)
|
||||
|
||||
def unnormalize(tensor):
|
||||
return (tensor / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
|
||||
class TrainVAEProcess(BaseTrainProcess):
|
||||
@@ -56,8 +52,12 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.save_every = self.get_conf('save_every', None)
|
||||
self.dtype = self.get_conf('dtype', 'float32')
|
||||
self.sample_sources = self.get_conf('sample_sources', None)
|
||||
self.style_weight = self.get_conf('style_weight', 1e4)
|
||||
self.content_weight = self.get_conf('content_weight', 1)
|
||||
self.elbo_weight = self.get_conf('elbo_weight', 1e-8)
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
self.save_root = os.path.join(self.training_folder, self.job.name)
|
||||
self.vgg_19 = None
|
||||
|
||||
if self.sample_every is not None and self.sample_sources is None:
|
||||
raise ValueError('sample_every is specified but sample_sources is not')
|
||||
@@ -66,7 +66,6 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
raise ValueError('epochs or max_steps must be specified')
|
||||
|
||||
self.data_loaders = []
|
||||
datasets = []
|
||||
# check datasets
|
||||
assert isinstance(self.datasets_objects, list)
|
||||
for dataset in self.datasets_objects:
|
||||
@@ -95,10 +94,17 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
num_workers=6
|
||||
)
|
||||
|
||||
def get_loss(self, pred, target):
|
||||
def setup_vgg19(self):
|
||||
if self.vgg_19 is None:
|
||||
self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses(
|
||||
single_target=True, device=self.device)
|
||||
self.vgg_19.requires_grad_(False)
|
||||
|
||||
def get_mse_loss(self, pred, target):
|
||||
loss_fn = nn.MSELoss()
|
||||
loss = loss_fn(pred, target)
|
||||
return loss
|
||||
@@ -157,8 +163,18 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
|
||||
input_img = img
|
||||
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
|
||||
decoded = self.vae(img).sample.squeeze(0)
|
||||
decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
|
||||
img = img
|
||||
decoded = self.vae(img).sample
|
||||
decoded = (decoded / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
|
||||
#convert to pillow image
|
||||
decoded = Image.fromarray((decoded * 255).astype(np.uint8))
|
||||
|
||||
# # decoded = decoded - 0.1
|
||||
# decoded = decoded
|
||||
# decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
|
||||
|
||||
# stack input image and decoded image
|
||||
input_img = input_img.resize((self.resolution, self.resolution))
|
||||
@@ -177,7 +193,6 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
i_str = str(i).zfill(2)
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
|
||||
output_img.save(os.path.join(sample_folder, filename))
|
||||
self.vae.decoder.train()
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
@@ -204,19 +219,41 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
print(f"Loading VAE")
|
||||
print(f" - Loading VAE: {self.vae_path}")
|
||||
if self.vae is None:
|
||||
# self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
|
||||
self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
|
||||
|
||||
# set decoder to train
|
||||
self.vae.to(self.device, dtype=self.torch_dtype)
|
||||
self.vae.requires_grad_(False)
|
||||
self.vae.eval()
|
||||
|
||||
self.vae.decoder.requires_grad_(True)
|
||||
self.vae.decoder.train()
|
||||
|
||||
parameters = self.vae.decoder.parameters()
|
||||
blocks_to_train = [
|
||||
'mid_block',
|
||||
'up_blocks',
|
||||
]
|
||||
|
||||
optimizer = torch.optim.Adam(parameters, lr=self.learning_rate)
|
||||
params = []
|
||||
|
||||
# only set last 2 layers to trainable
|
||||
for param in self.vae.decoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if 'mid_block' in blocks_to_train:
|
||||
params += list(self.vae.decoder.mid_block.parameters())
|
||||
self.vae.decoder.mid_block.requires_grad_(True)
|
||||
if 'up_blocks' in blocks_to_train:
|
||||
params += list(self.vae.decoder.up_blocks.parameters())
|
||||
self.vae.decoder.up_blocks.requires_grad_(True)
|
||||
|
||||
# self.vae.decoder.train()
|
||||
|
||||
self.setup_vgg19()
|
||||
self.vgg_19.requires_grad_(False)
|
||||
self.vgg_19.eval()
|
||||
|
||||
|
||||
optimizer = torch.optim.Adam(params, lr=self.learning_rate)
|
||||
|
||||
# setup scheduler
|
||||
# scheduler = lr_scheduler.ConstantLR
|
||||
@@ -249,6 +286,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
|
||||
# forward pass
|
||||
# with torch.no_grad():
|
||||
# batch = batch + 0.1
|
||||
dgd = self.vae.encode(batch).latent_dist
|
||||
mu, logvar = dgd.mean, dgd.logvar
|
||||
latents = dgd.sample()
|
||||
@@ -256,7 +294,24 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
|
||||
pred = self.vae.decode(latents).sample
|
||||
|
||||
loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
||||
# pred = pred + 0.1
|
||||
|
||||
# loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
||||
|
||||
stacked = torch.cat([pred, batch], dim=0)
|
||||
stacked = (stacked / 2 + 0.5).clamp(0, 1)
|
||||
self.vgg_19(stacked)
|
||||
# reduce the mean of the style_loss list
|
||||
|
||||
style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
|
||||
content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
|
||||
elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
||||
# elbo_loss = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
|
||||
style_loss = style_loss * self.style_weight
|
||||
content_loss = content_loss * self.content_weight
|
||||
elbo_loss = elbo_loss * self.elbo_weight
|
||||
|
||||
loss = style_loss + content_loss + elbo_loss
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -267,9 +322,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
# update progress bar
|
||||
loss_value = loss.item()
|
||||
# get exponent like 3.54e-4
|
||||
loss_string = f"{loss_value:.2e}"
|
||||
loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}"
|
||||
learning_rate = optimizer.param_groups[0]['lr']
|
||||
progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} Loss: {loss_string}")
|
||||
progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
|
||||
progress_bar.set_description(f"E: {epoch} - S: {step} ")
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -279,7 +334,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
print(f"Sampling at step {step}")
|
||||
self.sample(step)
|
||||
|
||||
if self.save_every and step % self.save_every == 0:
|
||||
if self.save_every and step % self.save_every == 0:
|
||||
# print above the progress bar
|
||||
print(f"Saving at step {step}")
|
||||
self.save(step)
|
||||
|
||||
Reference in New Issue
Block a user