diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 48b2a638..f9b810fe 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -8,6 +8,7 @@ from typing import Union
 
 # from lycoris.config import PRESET
 from torch.utils.data import DataLoader
+from toolkit.basic import value_map
 from toolkit.data_loader import get_dataloader_from_datasets
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
@@ -401,10 +402,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
         batch_size = latents.shape[0]
 
         self.sd.noise_scheduler.set_timesteps(
-            self.train_config.max_denoising_steps, device=self.device_torch
+            1000, device=self.device_torch
         )
 
+        if self.train_config.use_progressive_denoising:
+            min_timestep = int(value_map(
+                self.step_num,
+                min_in=0,
+                max_in=self.train_config.max_denoising_steps,
+                min_out=self.train_config.min_denoising_steps,
+                max_out=self.train_config.max_denoising_steps
+            ))
-        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
+        elif self.train_config.use_linear_denoising:
+            # starts at max steps and walks down to min steps
+            min_timestep = int(value_map(
+                self.step_num,
+                min_in=0,
+                max_in=self.train_config.max_denoising_steps,
+                min_out=self.train_config.max_denoising_steps - 1,
+                max_out=self.train_config.min_denoising_steps
+            ))
+        else:
+            min_timestep = self.train_config.min_denoising_steps
+
+        timesteps = torch.randint(
+            min_timestep,
+            self.train_config.max_denoising_steps,
+            (batch_size,),
+            device=self.device_torch
+        )
         timesteps = timesteps.long()
 
         # get noise
@@ -760,7 +786,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         # setup the networks to gradient checkpointing and everything works
         with torch.no_grad():
-            torch.cuda.empty_cache()
+            # torch.cuda.empty_cache()
             if self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):
                 learning_rate = (
@@ -814,8 +840,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 batch.cleanup()
 
             # flush every 10 steps
-            if self.step_num % 10 == 0:
-                flush()
+            # if self.step_num % 10 == 0:
+            #     flush()
 
         self.progress_bar.close()
         self.sample(self.step_num + 1)
diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py
index 169eff98..d1885de2 100644
--- a/jobs/process/BaseTrainProcess.py
+++ b/jobs/process/BaseTrainProcess.py
@@ -1,8 +1,10 @@
+import random
 from datetime import datetime
 import os
 from collections import OrderedDict
 from typing import TYPE_CHECKING, Union
 
+import torch
 import yaml
 
 from jobs.process.BaseProcess import BaseProcess
@@ -28,6 +30,14 @@ class BaseTrainProcess(BaseProcess):
 
         self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
         self.progress_bar: 'tqdm' = None
+        self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None)
+        # if training seed is set, use it
+        if self.training_seed is not None:
+            torch.manual_seed(self.training_seed)
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed(self.training_seed)
+            random.seed(self.training_seed)
+
         self.progress_bar = None
         self.writer = None
         self.training_folder = self.get_conf('training_folder',
diff --git a/run.py b/run.py
index 81c46653..54728c3d 100644
--- a/run.py
+++ b/run.py
@@ -4,7 +4,7 @@ from typing import Union, OrderedDict
 
 sys.path.insert(0, os.getcwd())
 # must come before ANY torch or fastai imports
-import toolkit.cuda_malloc
+# import toolkit.cuda_malloc
 
 import argparse
 from toolkit.job import get_job
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index bb1b321e..6c7e7895 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -77,7 +77,10 @@ class TrainConfig:
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
         self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
+        self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
         self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
+        self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False)
+        self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False)
         self.batch_size: int = kwargs.get('batch_size', 1)
         self.dtype: str = kwargs.get('dtype', 'fp32')
         self.xformers = kwargs.get('xformers', False)
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 1be02c67..8fe7897e 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -418,7 +418,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         file_item.load_caption(self.caption_dict)
         return file_item
 
-    @lru_cache(maxsize=300)
     def __getitem__(self, item):
         if self.dataset_config.buckets:
             # for buckets we collate ourselves for now
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 68b0c76f..59edf2c9 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -53,7 +53,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
 
         self.tensor: Union[torch.Tensor, None] = None
 
     def cleanup(self):
-        del self.tensor
+        self.tensor = None
         self.cleanup_latent()
 
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 08b316f5..8b5b8ca6 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -334,6 +334,9 @@ class ToolkitNetworkMixin:
 
     @multiplier.setter
     def multiplier(self, value: Union[float, List[float]]):
+        # only update if the value has changed
+        if self._multiplier == value:
+            return
         self._multiplier = value
         self._update_lora_multiplier()
 
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 3ba7638e..bd303eee 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -138,9 +138,9 @@ class StableDiffusion:
         self.noise_scheduler = scheduler
 
         # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
-        self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
-        self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
-        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
+        # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
+        # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
+        # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
 
         model_path = self.model_config.name_or_path
         if 'civitai.com' in self.model_config.name_or_path:
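
Note: the patch imports value_map from toolkit.basic but does not include its implementation. Judging only from the call sites in BaseSDTrainProcess above (keyword arguments min_in, max_in, min_out, max_out), it presumably performs a plain linear remap of the current step number from the input range onto the output range. A minimal sketch under that assumption, not the toolkit's actual code:

# Assumed behavior only; toolkit.basic.value_map is not shown in this diff.
# Linearly remap `val` from [min_in, max_in] onto [min_out, max_out].
def value_map(val, min_in, max_in, min_out, max_out):
    # fraction of the way through the input range
    t = (val - min_in) / float(max_in - min_in)
    return min_out + t * (max_out - min_out)

# With use_progressive_denoising, min_timestep rises from min_denoising_steps
# toward max_denoising_steps as step_num grows.
# With use_linear_denoising, min_out > max_out, so the same formula walks
# min_timestep from max_denoising_steps - 1 down toward min_denoising_steps.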
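
The new options surface as keyword arguments on TrainConfig (training_seed is read separately through get_conf in BaseTrainProcess, falling back to job.training_seed when the job defines it). A hypothetical usage sketch, assuming TrainConfig accepts these as keyword arguments as the kwargs.get calls suggest; the values are made up for illustration:

from toolkit.config_modules import TrainConfig

# Key names match the kwargs.get() calls added in toolkit/config_modules.py above.
train_config = TrainConfig(
    min_denoising_steps=0,            # lower bound for sampled timesteps
    max_denoising_steps=1000,         # exclusive upper bound passed to torch.randint
    use_progressive_denoising=True,   # min_timestep climbs from min to max over training
    use_linear_denoising=False,       # alternative: min_timestep walks from max down to min
)
assert train_config.use_progressive_denoising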