mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Bug fixes, speed improvements, compatability adjustments withdiffusers updates
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Union
|
||||
# from lycoris.config import PRESET
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_loader import get_dataloader_from_datasets
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.embedding import Embedding
|
||||
@@ -401,10 +402,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
if self.train_config.use_progressive_denoising:
|
||||
min_timestep = int(value_map(
|
||||
self.step_num,
|
||||
min_in=0,
|
||||
max_in=self.train_config.max_denoising_steps,
|
||||
min_out=self.train_config.min_denoising_steps,
|
||||
max_out=self.train_config.max_denoising_steps
|
||||
))
|
||||
|
||||
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
|
||||
elif self.train_config.use_linear_denoising:
|
||||
# starts at max steps and walks down to min steps
|
||||
min_timestep = int(value_map(
|
||||
self.step_num,
|
||||
min_in=0,
|
||||
max_in=self.train_config.max_denoising_steps,
|
||||
min_out=self.train_config.max_denoising_steps - 1,
|
||||
max_out=self.train_config.min_denoising_steps
|
||||
))
|
||||
else:
|
||||
min_timestep = self.train_config.min_denoising_steps
|
||||
|
||||
timesteps = torch.randint(
|
||||
min_timestep,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# get noise
|
||||
@@ -760,7 +786,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# setup the networks to gradient checkpointing and everything works
|
||||
|
||||
with torch.no_grad():
|
||||
torch.cuda.empty_cache()
|
||||
# torch.cuda.empty_cache()
|
||||
if self.train_config.optimizer.lower().startswith('dadaptation') or \
|
||||
self.train_config.optimizer.lower().startswith('prodigy'):
|
||||
learning_rate = (
|
||||
@@ -814,8 +840,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
batch.cleanup()
|
||||
|
||||
# flush every 10 steps
|
||||
if self.step_num % 10 == 0:
|
||||
flush()
|
||||
# if self.step_num % 10 == 0:
|
||||
# flush()
|
||||
|
||||
self.progress_bar.close()
|
||||
self.sample(self.step_num + 1)
|
||||
|
||||
Reference in New Issue
Block a user