Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Bug fixes, speed improvements, compatibility adjustments with diffusers updates
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Union
 # from lycoris.config import PRESET
 from torch.utils.data import DataLoader

+from toolkit.basic import value_map
 from toolkit.data_loader import get_dataloader_from_datasets
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
@@ -401,10 +402,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
         batch_size = latents.shape[0]

         self.sd.noise_scheduler.set_timesteps(
-            self.train_config.max_denoising_steps, device=self.device_torch
+            1000, device=self.device_torch
         )
-        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
+        if self.train_config.use_progressive_denoising:
+            min_timestep = int(value_map(
+                self.step_num,
+                min_in=0,
+                max_in=self.train_config.max_denoising_steps,
+                min_out=self.train_config.min_denoising_steps,
+                max_out=self.train_config.max_denoising_steps
+            ))
+        elif self.train_config.use_linear_denoising:
+            # starts at max steps and walks down to min steps
+            min_timestep = int(value_map(
+                self.step_num,
+                min_in=0,
+                max_in=self.train_config.max_denoising_steps,
+                min_out=self.train_config.max_denoising_steps - 1,
+                max_out=self.train_config.min_denoising_steps
+            ))
+        else:
+            min_timestep = self.train_config.min_denoising_steps
+
+        timesteps = torch.randint(
+            min_timestep,
+            self.train_config.max_denoising_steps,
+            (batch_size,),
+            device=self.device_torch
+        )
         timesteps = timesteps.long()

         # get noise
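The value_map helper imported from toolkit.basic in the first hunk is not shown in this diff. Judging from the keyword arguments, it behaves like a simple linear remap from one range to another; a minimal sketch consistent with that usage (an assumption, not the repository's actual implementation):

    def value_map(value, min_in, max_in, min_out, max_out):
        # Linearly remap value from [min_in, max_in] to [min_out, max_out].
        # An inverted output range (min_out > max_out) is allowed, which is how the
        # linear-denoising branch walks min_timestep downward over training.
        scale = (max_out - min_out) / (max_in - min_in)
        return min_out + (value - min_in) * scale

Read this way, progressive denoising raises min_timestep from min_denoising_steps toward max_denoising_steps as step_num runs from 0 to max_denoising_steps, while linear denoising starts it near max_denoising_steps - 1 and walks it down to min_denoising_steps, so the earliest steps sample only the noisiest timesteps.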
@@ -760,7 +786,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # setup the networks to gradient checkpointing and everything works

         with torch.no_grad():
-            torch.cuda.empty_cache()
+            # torch.cuda.empty_cache()
             if self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):
                 learning_rate = (
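A side note on the dadaptation/prodigy branch (the learning_rate assignment is truncated in this view): both optimizers estimate the step size internally, so the configured learning rate is conventionally left at 1.0 and treated as a multiplier. A hedged usage sketch, assuming the prodigyopt package rather than anything from this repository:

    import torch
    from prodigyopt import Prodigy  # pip install prodigyopt

    params = [torch.nn.Parameter(torch.randn(8, 8))]
    optimizer = Prodigy(params, lr=1.0)  # lr scales the internally adapted step size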
@@ -814,8 +840,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             batch.cleanup()

             # flush every 10 steps
-            if self.step_num % 10 == 0:
-                flush()
+            # if self.step_num % 10 == 0:
+            #     flush()

         self.progress_bar.close()
         self.sample(self.step_num + 1)
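flush() is not defined in this diff. In similar training loops it usually wraps Python garbage collection plus a CUDA cache flush, both of which are comparatively slow, which would explain dropping the every-10-steps call as one of the speed improvements. A sketch of what such a helper commonly looks like (an assumption about this codebase):

    import gc
    import torch

    def flush():
        # Collect unreachable Python objects, then return cached CUDA blocks to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()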
@@ -1,8 +1,10 @@
+import random
 from datetime import datetime
 import os
 from collections import OrderedDict
 from typing import TYPE_CHECKING, Union

+import torch
 import yaml

 from jobs.process.BaseProcess import BaseProcess
@@ -28,6 +30,14 @@ class BaseTrainProcess(BaseProcess):
         self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
         self.progress_bar: 'tqdm' = None

+        self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None)
+        # if training seed is set, use it
+        if self.training_seed is not None:
+            torch.manual_seed(self.training_seed)
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed(self.training_seed)
+            random.seed(self.training_seed)
+
         self.progress_bar = None
         self.writer = None
         self.training_folder = self.get_conf('training_folder',
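The effect of the new training_seed option is easy to verify: with the same seed, the stochastic parts of the loop (such as the timestep sampling above) repeat exactly. A minimal illustration:

    import torch

    torch.manual_seed(42)
    first = torch.randint(0, 1000, (4,))

    torch.manual_seed(42)
    second = torch.randint(0, 1000, (4,))

    assert torch.equal(first, second)  # identical draws under the same seed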
run.py
@@ -4,7 +4,7 @@ from typing import Union, OrderedDict

 sys.path.insert(0, os.getcwd())
 # must come before ANY torch or fastai imports
-import toolkit.cuda_malloc
+# import toolkit.cuda_malloc
 import argparse
 from toolkit.job import get_job
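toolkit.cuda_malloc is not shown here, but modules in this role typically configure the CUDA caching allocator through an environment variable before CUDA is initialized, which is why the comment insists on import order. A sketch of that pattern (an assumption, not the module's actual contents):

    import os

    # Must be set before CUDA is initialized; placing it before `import torch` is the safe ordering.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")

    import torch  # noqa: E402  (deliberately imported after the allocator is configured)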
@@ -77,7 +77,10 @@ class TrainConfig:
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
         self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
+        self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
         self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
+        self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False)
+        self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False)
         self.batch_size: int = kwargs.get('batch_size', 1)
         self.dtype: str = kwargs.get('dtype', 'fp32')
         self.xformers = kwargs.get('xformers', False)
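How the new options might be passed, with illustrative values that are not taken from the commit; since TrainConfig reads everything from kwargs with defaults, a partial set is enough to show the intent:

    train_kwargs = dict(
        min_denoising_steps=0,            # lower bound for sampled timesteps
        max_denoising_steps=1000,         # upper bound for sampled timesteps
        use_linear_denoising=True,        # lower bound walks down from max - 1 to min over training
        use_progressive_denoising=False,  # alternatively, lower bound rises from min toward max
    )

TrainConfig(**train_kwargs) should then pick these up, assuming the remaining fields keep their defaults.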
@@ -418,7 +418,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
             file_item.load_caption(self.caption_dict)
         return file_item

-    @lru_cache(maxsize=300)
    def __getitem__(self, item):
        if self.dataset_config.buckets:
            # for buckets we collate ourselves for now
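If the @lru_cache line is indeed the removed one (the hunk shrinks by a single line), this looks like part of the bug fixes: caching a Dataset's __getitem__ pins up to maxsize returned samples (and the bound self) in memory for the life of the process, and it can hand back stale items when bucketing or augmentation changes what an index should produce. A small self-contained illustration of the retention issue:

    from functools import lru_cache

    class ToyDataset:
        @lru_cache(maxsize=300)
        def __getitem__(self, index):
            # Stand-in for a decoded image; roughly 100 kB per entry.
            return bytearray(100_000)

    ds = ToyDataset()
    for i in range(300):
        ds[i]  # after this loop, all 300 entries remain pinned in the cache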
@@ -53,7 +53,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.tensor: Union[torch.Tensor, None] = None

     def cleanup(self):
         del self.tensor
         self.tensor = None
         self.cleanup_latent()
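A note on the del-then-None pattern in cleanup(): deleting the attribute drops the reference so the tensor's storage can be reclaimed, and immediately reassigning None keeps the attribute defined for any later access. In isolation:

    import torch

    class Holder:
        def __init__(self):
            self.tensor = torch.zeros(1024, 1024)

        def cleanup(self):
            del self.tensor      # release the reference; the storage can now be freed
            self.tensor = None   # keep the attribute present so later reads do not raise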
@@ -334,6 +334,9 @@ class ToolkitNetworkMixin:

     @multiplier.setter
     def multiplier(self, value: Union[float, List[float]]):
+        # only update if the value has changed
+        if self._multiplier == value:
+            return
         self._multiplier = value
         self._update_lora_multiplier()
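The new guard is a small speed win: assuming _update_lora_multiplier pushes the value into every LoRA module (a guess from the name, not shown in this diff), returning early when nothing changed avoids that walk on every call. The same early-return pattern in isolation, with illustrative names:

    class Scaled:
        def __init__(self):
            self._multiplier = 1.0

        @property
        def multiplier(self):
            return self._multiplier

        @multiplier.setter
        def multiplier(self, value):
            if self._multiplier == value:
                return               # a no-op write skips the propagation below
            self._multiplier = value
            self._propagate(value)   # stand-in for updating many submodules

        def _propagate(self, value):
            pass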
@@ -138,9 +138,9 @@ class StableDiffusion:
         self.noise_scheduler = scheduler

         # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
-        self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
-        self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
-        self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
+        # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
+        # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
+        # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)

         model_path = self.model_config.name_or_path
         if 'civitai.com' in self.model_config.name_or_path:
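The commented-out lines above patched the scheduler's schedule tensors onto the GPU at load time; the compatibility note in the commit message suggests newer diffusers releases make that unnecessary or unwanted. If the "stuck on cpu" problem resurfaces, a more defensive option is to move the tensors only when they are actually on the wrong device; a sketch, not part of the commit:

    import torch

    def ensure_scheduler_on_device(scheduler, device: torch.device):
        # Move the schedule tensors only if they exist and are not already on `device`.
        for name in ("betas", "alphas", "alphas_cumprod"):
            tensor = getattr(scheduler, name, None)
            if isinstance(tensor, torch.Tensor) and tensor.device != device:
                setattr(scheduler, name, tensor.to(device))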