mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
implemented device placement preset system more places. Vastly improved speed on setting network multiplier and activating network. Fixed timing issues on progress bar
This commit is contained in:
@@ -7,6 +7,8 @@ from typing import Union
|
||||
|
||||
# from lycoris.config import PRESET
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
import torch.backends.cuda
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_loader import get_dataloader_from_datasets
|
||||
@@ -21,6 +23,7 @@ from toolkit.progress_bar import ToolkitProgressBar
|
||||
from toolkit.sampler import get_sampler
|
||||
|
||||
from toolkit.scheduler import get_lr_scheduler
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
@@ -28,7 +31,6 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
||||
@@ -135,6 +137,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.network: Union[Network, None] = None
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
# get the device state preset based on what we are training
|
||||
self.train_device_state_preset = get_train_sd_device_state_preset(
|
||||
device=self.device_torch,
|
||||
train_unet=self.train_config.train_unet,
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
cached_latents=self.is_latents_cached,
|
||||
train_lora=self.network_config is not None,
|
||||
train_embedding=self.embed_config is not None,
|
||||
)
|
||||
|
||||
def sample(self, step=None, is_first=False):
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
gen_img_config_list = []
|
||||
@@ -477,6 +489,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# if it has it
|
||||
if hasattr(te, 'enable_xformers_memory_efficient_attention'):
|
||||
te.enable_xformers_memory_efficient_attention()
|
||||
if self.train_config.sdp:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
@@ -513,7 +529,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.datasets is not None:
|
||||
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
|
||||
if self.datasets_reg is not None:
|
||||
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd)
|
||||
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
|
||||
self.sd)
|
||||
|
||||
if self.network_config is not None:
|
||||
# TODO should we completely switch to LycorisSpecialNetwork?
|
||||
@@ -547,6 +564,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
# give network to sd so it can use it
|
||||
self.sd.network = self.network
|
||||
self.network._update_torch_multiplier()
|
||||
|
||||
self.network.apply_to(
|
||||
text_encoder,
|
||||
@@ -621,32 +639,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if not params:
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
|
||||
flush()
|
||||
else:
|
||||
# set them to train or not
|
||||
if self.train_config.train_unet:
|
||||
self.sd.unet.requires_grad_(True)
|
||||
self.sd.unet.train()
|
||||
else:
|
||||
self.sd.unet.requires_grad_(False)
|
||||
self.sd.unet.eval()
|
||||
|
||||
if self.train_config.train_text_encoder:
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.requires_grad_(True)
|
||||
te.train()
|
||||
else:
|
||||
self.sd.text_encoder.requires_grad_(True)
|
||||
self.sd.text_encoder.train()
|
||||
else:
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.requires_grad_(False)
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.requires_grad_(False)
|
||||
self.sd.text_encoder.eval()
|
||||
# set the device state preset before getting params
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
|
||||
params = self.get_params()
|
||||
|
||||
@@ -729,25 +727,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# zero any gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
self.lr_scheduler.step(self.step_num)
|
||||
|
||||
if self.embedding is not None or self.train_config.train_text_encoder:
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.train()
|
||||
else:
|
||||
self.sd.text_encoder.train()
|
||||
else:
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
if self.train_config.train_unet or self.embedding:
|
||||
self.sd.unet.train()
|
||||
else:
|
||||
self.sd.unet.eval()
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
flush()
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
|
||||
Reference in New Issue
Block a user