Bug fixes, speed improvements, compatability adjustments withdiffusers updates

This commit is contained in:
Jaret Burkett
2023-09-13 07:03:53 -06:00
parent d8d1e6fd1e
commit ae70200d3c
8 changed files with 52 additions and 11 deletions

View File

@@ -8,6 +8,7 @@ from typing import Union
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
@@ -401,10 +402,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch_size = latents.shape[0]
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
1000, device=self.device_torch
)
if self.train_config.use_progressive_denoising:
min_timestep = int(value_map(
self.step_num,
min_in=0,
max_in=self.train_config.max_denoising_steps,
min_out=self.train_config.min_denoising_steps,
max_out=self.train_config.max_denoising_steps
))
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
elif self.train_config.use_linear_denoising:
# starts at max steps and walks down to min steps
min_timestep = int(value_map(
self.step_num,
min_in=0,
max_in=self.train_config.max_denoising_steps,
min_out=self.train_config.max_denoising_steps - 1,
max_out=self.train_config.min_denoising_steps
))
else:
min_timestep = self.train_config.min_denoising_steps
timesteps = torch.randint(
min_timestep,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
# get noise
@@ -760,7 +786,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# setup the networks to gradient checkpointing and everything works
with torch.no_grad():
torch.cuda.empty_cache()
# torch.cuda.empty_cache()
if self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
@@ -814,8 +840,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch.cleanup()
# flush every 10 steps
if self.step_num % 10 == 0:
flush()
# if self.step_num % 10 == 0:
# flush()
self.progress_bar.close()
self.sample(self.step_num + 1)