mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Default to training with MSE loss only. Did a lot of cleanup in the training script. Added logging via TensorBoard.
This commit is contained in:
@@ -28,6 +28,9 @@ class TrainJob(BaseJob):
|
|||||||
self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
||||||
self.logging_dir = self.get_conf('logging_dir', None)
|
self.logging_dir = self.get_conf('logging_dir', None)
|
||||||
|
|
||||||
|
self.writer = None
|
||||||
|
self.setup_tensorboard()
|
||||||
|
|
||||||
# loads the processes from the config
|
# loads the processes from the config
|
||||||
self.load_processes(process_dict)
|
self.load_processes(process_dict)
|
||||||
|
|
||||||
@@ -38,3 +41,11 @@ class TrainJob(BaseJob):
|
|||||||
|
|
||||||
for process in self.process:
|
for process in self.process:
|
||||||
process.run()
|
process.run()
|
||||||
|
|
||||||
|
def setup_tensorboard(self):
    """Attach a TensorBoard ``SummaryWriter`` to ``self.writer`` if configured.

    Nothing happens when ``self.logging_dir`` is falsy. Otherwise a writer is
    created that logs into ``self.logging_dir`` and tags its event file with
    the job name.
    """
    if not self.logging_dir:
        return

    # imported here so the tensorboard module is only loaded when logging is on
    from torch.utils.tensorboard import SummaryWriter

    event_suffix = f"_{self.name}"
    self.writer = SummaryWriter(log_dir=self.logging_dir, filename_suffix=event_suffix)
|
||||||
|
|||||||
@@ -45,19 +45,26 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
||||||
self.batch_size = self.get_conf('batch_size', 1)
|
self.batch_size = self.get_conf('batch_size', 1)
|
||||||
self.resolution = self.get_conf('resolution', 256)
|
self.resolution = self.get_conf('resolution', 256)
|
||||||
self.learning_rate = self.get_conf('learning_rate', 1e-4)
|
self.learning_rate = self.get_conf('learning_rate', 1e-6)
|
||||||
self.sample_every = self.get_conf('sample_every', None)
|
self.sample_every = self.get_conf('sample_every', None)
|
||||||
self.epochs = self.get_conf('epochs', None)
|
self.epochs = self.get_conf('epochs', None)
|
||||||
self.max_steps = self.get_conf('max_steps', None)
|
self.max_steps = self.get_conf('max_steps', None)
|
||||||
self.save_every = self.get_conf('save_every', None)
|
self.save_every = self.get_conf('save_every', None)
|
||||||
self.dtype = self.get_conf('dtype', 'float32')
|
self.dtype = self.get_conf('dtype', 'float32')
|
||||||
self.sample_sources = self.get_conf('sample_sources', None)
|
self.sample_sources = self.get_conf('sample_sources', None)
|
||||||
self.style_weight = self.get_conf('style_weight', 1e4)
|
self.log_every = self.get_conf('log_every', 100)
|
||||||
self.content_weight = self.get_conf('content_weight', 1)
|
self.style_weight = self.get_conf('style_weight', 0)
|
||||||
self.elbo_weight = self.get_conf('elbo_weight', 1e-8)
|
self.content_weight = self.get_conf('content_weight', 0)
|
||||||
|
self.kld_weight = self.get_conf('kld_weight', 0)
|
||||||
|
self.mse_weight = self.get_conf('mse_weight', 1e0)
|
||||||
|
|
||||||
|
|
||||||
|
self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
|
||||||
|
self.writer = self.job.writer
|
||||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||||
self.save_root = os.path.join(self.training_folder, self.job.name)
|
self.save_root = os.path.join(self.training_folder, self.job.name)
|
||||||
self.vgg_19 = None
|
self.vgg_19 = None
|
||||||
|
self.progress_bar = None
|
||||||
|
|
||||||
if self.sample_every is not None and self.sample_sources is None:
|
if self.sample_every is not None and self.sample_sources is None:
|
||||||
raise ValueError('sample_every is specified but sample_sources is not')
|
raise ValueError('sample_every is specified but sample_sources is not')
|
||||||
@@ -79,6 +86,13 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
if not os.path.exists(self.save_root):
|
if not os.path.exists(self.save_root):
|
||||||
os.makedirs(self.save_root, exist_ok=True)
|
os.makedirs(self.save_root, exist_ok=True)
|
||||||
|
|
||||||
|
def print(self, message, **kwargs):
    """Print ``message`` without corrupting the active progress bar.

    While a progress bar exists (``self.progress_bar`` is not ``None``) the
    message is routed through the bar's ``write`` method so it is printed
    above the bar; otherwise it goes to the builtin ``print``.

    Args:
        message: Text to display.
        **kwargs: Forwarded unchanged to ``progress_bar.write`` or ``print``.
    """
    if self.progress_bar is None:
        print(message, **kwargs)
        return

    # Only write here. Calling ``update()`` as well would advance the bar's
    # counter by one for every message, so the displayed progress would drift
    # ahead of the real training step (the training loop already calls
    # ``update(1)`` once per step).
    self.progress_bar.write(message, **kwargs)
|
||||||
|
|
||||||
def load_datasets(self):
|
def load_datasets(self):
|
||||||
if self.data_loader is None:
|
if self.data_loader is None:
|
||||||
print(f"Loading datasets")
|
print(f"Loading datasets")
|
||||||
@@ -104,17 +118,36 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
single_target=True, device=self.device)
|
single_target=True, device=self.device)
|
||||||
self.vgg_19.requires_grad_(False)
|
self.vgg_19.requires_grad_(False)
|
||||||
|
|
||||||
def get_mse_loss(self, pred, target):
|
def get_style_loss(self):
|
||||||
loss_fn = nn.MSELoss()
|
if self.style_weight > 0:
|
||||||
loss = loss_fn(pred, target)
|
return torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
|
||||||
return loss
|
else:
|
||||||
|
return torch.tensor(0.0, device=self.device)
|
||||||
|
|
||||||
def get_elbo_loss(self, pred, target, mu, log_var):
|
def get_content_loss(self):
|
||||||
# ELBO (Evidence Lower BOund) loss, aka variational lower bound
|
if self.content_weight > 0:
|
||||||
reconstruction_loss = nn.MSELoss(reduction='sum')
|
return torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
|
||||||
BCE = reconstruction_loss(pred, target) # reconstruction loss
|
else:
|
||||||
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL divergence
|
return torch.tensor(0.0, device=self.device)
|
||||||
return BCE + KLD
|
|
||||||
|
def get_mse_loss(self, pred, target):
    """Return the mean-squared error between ``pred`` and ``target``.

    The term is only computed when ``self.mse_weight`` is positive; otherwise
    a zero scalar tensor created on ``self.device`` is returned so callers can
    still add it to the other loss terms.
    """
    if not self.mse_weight > 0:
        # loss term disabled
        return torch.tensor(0.0, device=self.device)

    criterion = nn.MSELoss()
    return criterion(pred, target)
|
||||||
|
|
||||||
|
def get_kld_loss(self, mu, log_var):
    """Return the KL-divergence loss term computed from ``mu`` and ``log_var``.

    The term is only computed when ``self.kld_weight`` is positive; otherwise
    a zero scalar tensor created on ``self.device`` is returned.
    """
    if not self.kld_weight > 0:
        # loss term disabled
        return torch.tensor(0.0, device=self.device)

    # Kullback-Leibler divergence
    # added here for full training (not implemented). Not needed for only decoder
    # as we are not changing the distribution of the latent space
    # normally it would help keep a normal distribution for latents
    elementwise = 1 + log_var - mu.pow(2) - log_var.exp()
    return -0.5 * torch.sum(elementwise)
|
||||||
|
|
||||||
def save(self, step=None):
|
def save(self, step=None):
|
||||||
if not os.path.exists(self.save_root):
|
if not os.path.exists(self.save_root):
|
||||||
@@ -126,7 +159,6 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
step_num = f"_{str(step).zfill(9)}"
|
step_num = f"_{str(step).zfill(9)}"
|
||||||
|
|
||||||
filename = f'{self.job.name}{step_num}.safetensors'
|
filename = f'{self.job.name}{step_num}.safetensors'
|
||||||
save_path = os.path.join(self.save_root, filename)
|
|
||||||
# prepare meta
|
# prepare meta
|
||||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||||
|
|
||||||
@@ -148,9 +180,6 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
os.makedirs(sample_folder, exist_ok=True)
|
os.makedirs(sample_folder, exist_ok=True)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.vae.encoder.eval()
|
|
||||||
self.vae.decoder.eval()
|
|
||||||
|
|
||||||
for i, img_url in enumerate(self.sample_sources):
|
for i, img_url in enumerate(self.sample_sources):
|
||||||
img = exif_transpose(Image.open(img_url))
|
img = exif_transpose(Image.open(img_url))
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
@@ -169,13 +198,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||||
decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||||
|
|
||||||
#convert to pillow image
|
# convert to pillow image
|
||||||
decoded = Image.fromarray((decoded * 255).astype(np.uint8))
|
decoded = Image.fromarray((decoded * 255).astype(np.uint8))
|
||||||
|
|
||||||
# # decoded = decoded - 0.1
|
|
||||||
# decoded = decoded
|
|
||||||
# decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
|
|
||||||
|
|
||||||
# stack input image and decoded image
|
# stack input image and decoded image
|
||||||
input_img = input_img.resize((self.resolution, self.resolution))
|
input_img = input_img.resize((self.resolution, self.resolution))
|
||||||
decoded = decoded.resize((self.resolution, self.resolution))
|
decoded = decoded.resize((self.resolution, self.resolution))
|
||||||
@@ -186,10 +211,10 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
step_num = ''
|
step_num = ''
|
||||||
if step is not None:
|
if step is not None:
|
||||||
# zeropad 9 digits
|
# zero-pad 9 digits
|
||||||
step_num = f"_{str(step).zfill(9)}"
|
step_num = f"_{str(step).zfill(9)}"
|
||||||
seconds_since_epoch = int(time.time())
|
seconds_since_epoch = int(time.time())
|
||||||
# zeropad 2 digits
|
# zero-pad 2 digits
|
||||||
i_str = str(i).zfill(2)
|
i_str = str(i).zfill(2)
|
||||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
|
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
|
||||||
output_img.save(os.path.join(sample_folder, filename))
|
output_img.save(os.path.join(sample_folder, filename))
|
||||||
@@ -208,18 +233,17 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
if num_steps is None or num_steps > max_epoch_steps:
|
if num_steps is None or num_steps > max_epoch_steps:
|
||||||
num_steps = max_epoch_steps
|
num_steps = max_epoch_steps
|
||||||
|
|
||||||
print(f"Training VAE")
|
self.print(f"Training VAE")
|
||||||
print(f" - Training folder: {self.training_folder}")
|
self.print(f" - Training folder: {self.training_folder}")
|
||||||
print(f" - Batch size: {self.batch_size}")
|
self.print(f" - Batch size: {self.batch_size}")
|
||||||
print(f" - Learning rate: {self.learning_rate}")
|
self.print(f" - Learning rate: {self.learning_rate}")
|
||||||
print(f" - Epochs: {num_epochs}")
|
self.print(f" - Epochs: {num_epochs}")
|
||||||
print(f" - Max steps: {self.max_steps}")
|
self.print(f" - Max steps: {self.max_steps}")
|
||||||
|
|
||||||
# load vae
|
# load vae
|
||||||
print(f"Loading VAE")
|
self.print(f"Loading VAE")
|
||||||
print(f" - Loading VAE: {self.vae_path}")
|
self.print(f" - Loading VAE: {self.vae_path}")
|
||||||
if self.vae is None:
|
if self.vae is None:
|
||||||
# self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
|
|
||||||
self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
|
self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
|
||||||
|
|
||||||
# set decoder to train
|
# set decoder to train
|
||||||
@@ -228,35 +252,36 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
self.vae.eval()
|
self.vae.eval()
|
||||||
self.vae.decoder.train()
|
self.vae.decoder.train()
|
||||||
|
|
||||||
blocks_to_train = [
|
|
||||||
'mid_block',
|
|
||||||
'up_blocks',
|
|
||||||
]
|
|
||||||
|
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
# only set last 2 layers to trainable
|
# only set last 2 layers to trainable
|
||||||
for param in self.vae.decoder.parameters():
|
for param in self.vae.decoder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
if 'mid_block' in blocks_to_train:
|
train_all = 'all' in self.blocks_to_train
|
||||||
|
|
||||||
|
# mid_block
|
||||||
|
if train_all or 'mid_block' in self.blocks_to_train:
|
||||||
params += list(self.vae.decoder.mid_block.parameters())
|
params += list(self.vae.decoder.mid_block.parameters())
|
||||||
self.vae.decoder.mid_block.requires_grad_(True)
|
self.vae.decoder.mid_block.requires_grad_(True)
|
||||||
if 'up_blocks' in blocks_to_train:
|
# up_blocks
|
||||||
|
if train_all or 'up_blocks' in self.blocks_to_train:
|
||||||
params += list(self.vae.decoder.up_blocks.parameters())
|
params += list(self.vae.decoder.up_blocks.parameters())
|
||||||
self.vae.decoder.up_blocks.requires_grad_(True)
|
self.vae.decoder.up_blocks.requires_grad_(True)
|
||||||
|
# conv_out (single conv layer output)
|
||||||
|
if train_all or 'conv_out' in self.blocks_to_train:
|
||||||
|
params += list(self.vae.decoder.conv_out.parameters())
|
||||||
|
self.vae.decoder.conv_out.requires_grad_(True)
|
||||||
|
|
||||||
# self.vae.decoder.train()
|
if self.style_weight > 0 or self.content_weight > 0:
|
||||||
|
self.setup_vgg19()
|
||||||
self.setup_vgg19()
|
self.vgg_19.requires_grad_(False)
|
||||||
self.vgg_19.requires_grad_(False)
|
self.vgg_19.eval()
|
||||||
self.vgg_19.eval()
|
|
||||||
|
|
||||||
|
|
||||||
|
# todo allow other optimizers
|
||||||
optimizer = torch.optim.Adam(params, lr=self.learning_rate)
|
optimizer = torch.optim.Adam(params, lr=self.learning_rate)
|
||||||
|
|
||||||
# setup scheduler
|
# setup scheduler
|
||||||
# scheduler = lr_scheduler.ConstantLR
|
|
||||||
# todo allow other schedulers
|
# todo allow other schedulers
|
||||||
scheduler = torch.optim.lr_scheduler.ConstantLR(
|
scheduler = torch.optim.lr_scheduler.ConstantLR(
|
||||||
optimizer,
|
optimizer,
|
||||||
@@ -266,7 +291,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# setup tqdm progress bar
|
# setup tqdm progress bar
|
||||||
progress_bar = tqdm(
|
self.progress_bar = tqdm(
|
||||||
total=num_steps,
|
total=num_steps,
|
||||||
desc='Training VAE',
|
desc='Training VAE',
|
||||||
leave=True
|
leave=True
|
||||||
@@ -275,6 +300,16 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
step = 0
|
step = 0
|
||||||
# sample first
|
# sample first
|
||||||
self.sample()
|
self.sample()
|
||||||
|
blank_losses = OrderedDict({
|
||||||
|
"total": [],
|
||||||
|
"style": [],
|
||||||
|
"content": [],
|
||||||
|
"mse": [],
|
||||||
|
"kl": []
|
||||||
|
})
|
||||||
|
epoch_losses = copy.deepcopy(blank_losses)
|
||||||
|
log_losses = copy.deepcopy(blank_losses)
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
if step >= num_steps:
|
if step >= num_steps:
|
||||||
break
|
break
|
||||||
@@ -285,8 +320,6 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
batch = batch.to(self.device, dtype=self.torch_dtype)
|
batch = batch.to(self.device, dtype=self.torch_dtype)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
# with torch.no_grad():
|
|
||||||
# batch = batch + 0.1
|
|
||||||
dgd = self.vae.encode(batch).latent_dist
|
dgd = self.vae.encode(batch).latent_dist
|
||||||
mu, logvar = dgd.mean, dgd.logvar
|
mu, logvar = dgd.mean, dgd.logvar
|
||||||
latents = dgd.sample()
|
latents = dgd.sample()
|
||||||
@@ -294,24 +327,18 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
pred = self.vae.decode(latents).sample
|
pred = self.vae.decode(latents).sample
|
||||||
|
|
||||||
# pred = pred + 0.1
|
# Run through VGG19
|
||||||
|
if self.style_weight > 0 or self.content_weight > 0:
|
||||||
|
stacked = torch.cat([pred, batch], dim=0)
|
||||||
|
stacked = (stacked / 2 + 0.5).clamp(0, 1)
|
||||||
|
self.vgg_19(stacked)
|
||||||
|
|
||||||
# loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
style_loss = self.get_style_loss() * self.style_weight
|
||||||
|
content_loss = self.get_content_loss() * self.content_weight
|
||||||
|
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
|
||||||
|
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
|
||||||
|
|
||||||
stacked = torch.cat([pred, batch], dim=0)
|
loss = style_loss + content_loss + kld_loss + mse_loss
|
||||||
stacked = (stacked / 2 + 0.5).clamp(0, 1)
|
|
||||||
self.vgg_19(stacked)
|
|
||||||
# reduce the mean of the style_loss list
|
|
||||||
|
|
||||||
style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
|
|
||||||
content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
|
|
||||||
elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
|
||||||
# elbo_loss = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
|
|
||||||
style_loss = style_loss * self.style_weight
|
|
||||||
content_loss = content_loss * self.content_weight
|
|
||||||
elbo_loss = elbo_loss * self.elbo_weight
|
|
||||||
|
|
||||||
loss = style_loss + content_loss + elbo_loss
|
|
||||||
|
|
||||||
# Backward pass and optimization
|
# Backward pass and optimization
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -322,25 +349,64 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
# update progress bar
|
# update progress bar
|
||||||
loss_value = loss.item()
|
loss_value = loss.item()
|
||||||
# get exponent like 3.54e-4
|
# get exponent like 3.54e-4
|
||||||
loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}"
|
loss_string = f"loss: {loss_value:.2e}"
|
||||||
|
if self.content_weight > 0:
|
||||||
|
loss_string += f" cnt: {content_loss.item():.2e}"
|
||||||
|
if self.style_weight > 0:
|
||||||
|
loss_string += f" sty: {style_loss.item():.2e}"
|
||||||
|
if self.kld_weight > 0:
|
||||||
|
loss_string += f" kld: {kld_loss.item():.2e}"
|
||||||
|
if self.mse_weight > 0:
|
||||||
|
loss_string += f" mse: {mse_loss.item():.2e}"
|
||||||
|
|
||||||
learning_rate = optimizer.param_groups[0]['lr']
|
learning_rate = optimizer.param_groups[0]['lr']
|
||||||
progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
|
self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
|
||||||
progress_bar.set_description(f"E: {epoch} - S: {step} ")
|
self.progress_bar.set_description(f"E: {epoch}")
|
||||||
progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
|
epoch_losses["total"].append(loss_value)
|
||||||
|
epoch_losses["style"].append(style_loss.item())
|
||||||
|
epoch_losses["content"].append(content_loss.item())
|
||||||
|
epoch_losses["mse"].append(mse_loss.item())
|
||||||
|
epoch_losses["kl"].append(kld_loss.item())
|
||||||
|
|
||||||
|
log_losses["total"].append(loss_value)
|
||||||
|
log_losses["style"].append(style_loss.item())
|
||||||
|
log_losses["content"].append(content_loss.item())
|
||||||
|
log_losses["mse"].append(mse_loss.item())
|
||||||
|
log_losses["kl"].append(kld_loss.item())
|
||||||
|
|
||||||
if step != 0:
|
if step != 0:
|
||||||
if self.sample_every and step % self.sample_every == 0:
|
if self.sample_every and step % self.sample_every == 0:
|
||||||
# print above the progress bar
|
# print above the progress bar
|
||||||
print(f"Sampling at step {step}")
|
self.print(f"Sampling at step {step}")
|
||||||
self.sample(step)
|
self.sample(step)
|
||||||
|
|
||||||
if self.save_every and step % self.save_every == 0:
|
if self.save_every and step % self.save_every == 0:
|
||||||
# print above the progress bar
|
# print above the progress bar
|
||||||
print(f"Saving at step {step}")
|
self.print(f"Saving at step {step}")
|
||||||
self.save(step)
|
self.save(step)
|
||||||
|
|
||||||
|
if self.log_every and step % self.log_every == 0:
|
||||||
|
# log to tensorboard
|
||||||
|
if self.writer is not None:
|
||||||
|
# get avg loss
|
||||||
|
for key in log_losses:
|
||||||
|
log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
|
||||||
|
if log_losses[key] > 0:
|
||||||
|
self.writer.add_scalar(f"loss/{key}", log_losses[key], step)
|
||||||
|
# reset log losses
|
||||||
|
log_losses = copy.deepcopy(blank_losses)
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
# end epoch
|
||||||
|
if self.writer is not None:
|
||||||
|
# get avg loss
|
||||||
|
for key in epoch_losses:
|
||||||
|
epoch_losses[key] = sum(log_losses[key]) / len(log_losses[key])
|
||||||
|
if epoch_losses[key] > 0:
|
||||||
|
self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
|
||||||
|
# reset epoch losses
|
||||||
|
epoch_losses = copy.deepcopy(blank_losses)
|
||||||
|
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from PIL import Image
|
|||||||
from PIL.ImageOps import exif_transpose
|
from PIL.ImageOps import exif_transpose
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
@@ -22,10 +23,17 @@ class ImageDataset(Dataset):
|
|||||||
|
|
||||||
# this might take a while
|
# this might take a while
|
||||||
print(f" - Preprocessing image dimensions")
|
print(f" - Preprocessing image dimensions")
|
||||||
self.file_list = [file for file in self.file_list if
|
new_file_list = []
|
||||||
int(min(Image.open(file).size) * self.scale) >= self.resolution]
|
bad_count = 0
|
||||||
|
for file in tqdm(self.file_list):
|
||||||
|
img = Image.open(file)
|
||||||
|
if int(min(img.size) * self.scale) >= self.resolution:
|
||||||
|
new_file_list.append(file)
|
||||||
|
else:
|
||||||
|
bad_count += 1
|
||||||
|
|
||||||
print(f" - Found {len(self.file_list)} images")
|
print(f" - Found {len(self.file_list)} images")
|
||||||
|
print(f" - Found {bad_count} images that are too small")
|
||||||
assert len(self.file_list) > 0, f"no images found in {self.path}"
|
assert len(self.file_list) > 0, f"no images found in {self.path}"
|
||||||
|
|
||||||
self.transform = transforms.Compose([
|
self.transform = transforms.Compose([
|
||||||
|
|||||||
Reference in New Issue
Block a user