Implemented the device placement preset system in more places. Vastly improved speed when setting the network multiplier and activating the network. Fixed timing issues in the progress bar.

Jaret Burkett
2023-09-14 08:31:54 -06:00
parent 4e945917df
commit 569d7464d5
9 changed files with 173 additions and 91 deletions


@@ -7,6 +7,8 @@ from typing import Union
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch
import torch.backends.cuda
from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
@@ -21,6 +23,7 @@ from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
from jobs.process import BaseTrainProcess
@@ -28,7 +31,6 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete
from toolkit.train_tools import get_torch_dtype
import gc
import torch
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
@@ -135,6 +137,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network: Union[Network, None] = None
self.embedding: Union[Embedding, None] = None
# get the device state preset based on what we are training
self.train_device_state_preset = get_train_sd_device_state_preset(
device=self.device_torch,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_embedding=self.embed_config is not None,
)
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
gen_img_config_list = []
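
The preset built above comes from toolkit.sd_device_states_presets, which this diff does not include. As a rough, non-authoritative sketch of the idea, assuming a preset simply records a target device plus train/grad flags for each pipeline component (every name below is hypothetical):

from dataclasses import dataclass, field

@dataclass
class ComponentState:
    device: str = 'cpu'       # where the module should live
    training: bool = False    # .train() vs .eval()
    requires_grad: bool = False

@dataclass
class DeviceStatePreset:
    unet: ComponentState = field(default_factory=ComponentState)
    text_encoder: ComponentState = field(default_factory=ComponentState)
    vae: ComponentState = field(default_factory=ComponentState)

def get_train_sd_device_state_preset(device, train_unet=False,
                                     train_text_encoder=False,
                                     cached_latents=False,
                                     train_lora=False,
                                     train_embedding=False):
    # Hypothetical sketch; the real module is not shown in this diff.
    preset = DeviceStatePreset()
    # The unet must be on-device for any training, but its base weights
    # only need gradients when trained directly (with a LoRA, gradients
    # live in the injected modules instead).
    preset.unet = ComponentState(device, train_unet,
                                 train_unet and not train_lora)
    te_used = train_text_encoder or train_embedding
    preset.text_encoder = ComponentState(device, te_used, train_text_encoder)
    # With cached latents the VAE is never touched during a step, so it
    # can stay on the CPU (the dataclass default); otherwise move it over.
    if not cached_latents:
        preset.vae = ComponentState(device, False, False)
    return preset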
@@ -477,6 +489,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if it has it
if hasattr(te, 'enable_xformers_memory_efficient_attention'):
te.enable_xformers_memory_efficient_attention()
if self.train_config.sdp:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
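
The three enable_*_sdp calls above are standard PyTorch 2.x toggles in torch.backends.cuda: they allow the math, FlashAttention, and memory-efficient backends of torch.nn.functional.scaled_dot_product_attention, and PyTorch then dispatches each call to the fastest kernel the inputs and hardware support. Minimal standalone usage:

import torch
import torch.backends.cuda
import torch.nn.functional as F

# Allow all three scaled-dot-product-attention backends; PyTorch picks
# the fastest one the inputs support at call time.
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

# Shapes are (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v)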
@@ -513,7 +529,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.datasets is not None:
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
if self.datasets_reg is not None:
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd)
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd)
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
@@ -547,6 +564,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network.force_to(self.device_torch, dtype=dtype)
# give network to sd so it can use it
self.sd.network = self.network
self.network._update_torch_multiplier()
self.network.apply_to(
text_encoder,
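
_update_torch_multiplier is internal to the network class and not shown in this diff. Judging by the commit message, the speed win likely comes from caching the multiplier in one shared tensor instead of looping over every injected module on each change; a hypothetical sketch of that pattern:

import torch

class Network:
    # Hypothetical sketch; the real Network class is not in this diff.
    def __init__(self, modules, device):
        self.modules = modules  # injected LoRA modules
        self._multiplier = 1.0
        # One shared tensor that every module's forward pass reads.
        self.torch_multiplier = torch.tensor(1.0, device=device)

    @property
    def multiplier(self):
        return self._multiplier

    @multiplier.setter
    def multiplier(self, value):
        self._multiplier = value
        self._update_torch_multiplier()

    def _update_torch_multiplier(self):
        # A single in-place write; no Python loop over modules and no
        # per-step re-assignment of each module's multiplier.
        self.torch_multiplier.fill_(self._multiplier)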
@@ -621,32 +639,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
flush()
else:
# set them to train or not
if self.train_config.train_unet:
self.sd.unet.requires_grad_(True)
self.sd.unet.train()
else:
self.sd.unet.requires_grad_(False)
self.sd.unet.eval()
if self.train_config.train_text_encoder:
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.requires_grad_(True)
te.train()
else:
self.sd.text_encoder.requires_grad_(True)
self.sd.text_encoder.train()
else:
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.requires_grad_(False)
te.eval()
else:
self.sd.text_encoder.requires_grad_(False)
self.sd.text_encoder.eval()
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
params = self.get_params()
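
Here the hand-rolled requires_grad/train/eval toggling deleted above is replaced by a single set_device_state call. Its implementation is not part of this diff either; under the same hypothetical preset shape sketched earlier, it might reduce to:

def set_device_state(sd, preset):
    # Hypothetical sketch of StableDiffusion.set_device_state; the real
    # implementation is not shown. sd is the model wrapper, preset the
    # DeviceStatePreset sketched earlier.
    for name in ('unet', 'text_encoder', 'vae'):
        state = getattr(preset, name)
        modules = getattr(sd, name)
        # text_encoder can be a list (SDXL ships two encoders)
        if not isinstance(modules, list):
            modules = [modules]
        for m in modules:
            m.to(state.device)
            m.requires_grad_(state.requires_grad)
            if state.training:
                m.train()
            else:
                m.eval()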
@@ -729,25 +727,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# zero any gradients
optimizer.zero_grad()
self.lr_scheduler.step(self.step_num)
if self.embedding is not None or self.train_config.train_text_encoder:
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.train()
else:
self.sd.text_encoder.train()
else:
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
else:
self.sd.text_encoder.eval()
if self.train_config.train_unet or self.embedding:
self.sd.unet.train()
else:
self.sd.unet.eval()
self.sd.set_device_state(self.train_device_state_preset)
flush()
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):