mirror of https://github.com/ostris/ai-toolkit.git
Implemented the device placement preset system in more places. Vastly improved speed when setting the network multiplier and activating the network. Fixed timing issues in the progress bar.
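Below is the rendered diff. The multiplier speedup comes from two changes visible in it: the setter now returns early when the value is unchanged, and a single cached tensor (torch_multiplier) is read by every module during the forward pass instead of each module being updated individually. A minimal sketch of that pattern, not the toolkit's actual class:

    import torch

    # Minimal sketch (not the toolkit's actual class) of the two tricks used
    # in the diff below: an early-out when the multiplier is unchanged, and
    # one cached tensor that every module reads instead of being updated
    # one by one.
    class SketchNetwork:
        def __init__(self):
            self._multiplier = 1.0
            self.torch_multiplier = torch.tensor((1.0,))

        @property
        def multiplier(self):
            return self._multiplier

        @multiplier.setter
        def multiplier(self, value):
            if self._multiplier == value:
                return  # skip the expensive per-module walk entirely
            self._multiplier = value
            # one tensor update; modules read torch_multiplier in forward
            self.torch_multiplier = torch.tensor((float(value),))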
@@ -29,6 +29,7 @@ class SDTrainer(BaseSDTrainProcess):
        else:
            # offload it. Already cached
            self.sd.vae.to('cpu')
            flush()

    def hook_train_loop(self, batch):
@@ -110,7 +111,5 @@ class SDTrainer(BaseSDTrainProcess):
        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )
        # reset network multiplier
        network.multiplier = 1.0

        return loss_dict
@@ -7,6 +7,8 @@ from typing import Union

# from lycoris.config import PRESET
from torch.utils.data import DataLoader
+import torch
+import torch.backends.cuda

from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
@@ -21,6 +23,7 @@ from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler

from toolkit.scheduler import get_lr_scheduler
+from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion

from jobs.process import BaseTrainProcess
@@ -28,7 +31,6 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete
from toolkit.train_tools import get_torch_dtype
import gc

-import torch
from tqdm import tqdm

from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
@@ -135,6 +137,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.network: Union[Network, None] = None
        self.embedding: Union[Embedding, None] = None

+        # get the device state preset based on what we are training
+        self.train_device_state_preset = get_train_sd_device_state_preset(
+            device=self.device_torch,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=self.train_config.train_text_encoder,
+            cached_latents=self.is_latents_cached,
+            train_lora=self.network_config is not None,
+            train_embedding=self.embed_config is not None,
+        )

    def sample(self, step=None, is_first=False):
        sample_folder = os.path.join(self.save_root, 'samples')
        gen_img_config_list = []
@@ -477,6 +489,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    # if it has it
                    if hasattr(te, 'enable_xformers_memory_efficient_attention'):
                        te.enable_xformers_memory_efficient_attention()
+        if self.train_config.sdp:
+            torch.backends.cuda.enable_math_sdp(True)
+            torch.backends.cuda.enable_flash_sdp(True)
+            torch.backends.cuda.enable_mem_efficient_sdp(True)

        if self.train_config.gradient_checkpointing:
            unet.enable_gradient_checkpointing()
@@ -513,7 +529,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.datasets is not None:
            self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
        if self.datasets_reg is not None:
-            self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd)
+            self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
+                                                                self.sd)

        if self.network_config is not None:
            # TODO should we completely switch to LycorisSpecialNetwork?
@@ -547,6 +564,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            self.network.force_to(self.device_torch, dtype=dtype)
            # give network to sd so it can use it
            self.sd.network = self.network
+            self.network._update_torch_multiplier()

            self.network.apply_to(
                text_encoder,
@@ -621,32 +639,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
            if not params:
                # set trainable params
                params = self.embedding.get_trainable_params()

            flush()
        else:
-            # set them to train or not
-            if self.train_config.train_unet:
-                self.sd.unet.requires_grad_(True)
-                self.sd.unet.train()
-            else:
-                self.sd.unet.requires_grad_(False)
-                self.sd.unet.eval()
-
-            if self.train_config.train_text_encoder:
-                if isinstance(self.sd.text_encoder, list):
-                    for te in self.sd.text_encoder:
-                        te.requires_grad_(True)
-                        te.train()
-                else:
-                    self.sd.text_encoder.requires_grad_(True)
-                    self.sd.text_encoder.train()
-            else:
-                if isinstance(self.sd.text_encoder, list):
-                    for te in self.sd.text_encoder:
-                        te.requires_grad_(False)
-                        te.eval()
-                else:
-                    self.sd.text_encoder.requires_grad_(False)
-                    self.sd.text_encoder.eval()
+            # set the device state preset before getting params
+            self.sd.set_device_state(self.train_device_state_preset)

            params = self.get_params()
@@ -729,25 +727,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # zero any gradients
        optimizer.zero_grad()

        self.lr_scheduler.step(self.step_num)

-        if self.embedding is not None or self.train_config.train_text_encoder:
-            if isinstance(self.sd.text_encoder, list):
-                for te in self.sd.text_encoder:
-                    te.train()
-            else:
-                self.sd.text_encoder.train()
-        else:
-            if isinstance(self.sd.text_encoder, list):
-                for te in self.sd.text_encoder:
-                    te.eval()
-            else:
-                self.sd.text_encoder.eval()
-        if self.train_config.train_unet or self.embedding:
-            self.sd.unet.train()
-        else:
-            self.sd.unet.eval()
+        self.sd.set_device_state(self.train_device_state_preset)
        flush()
        # self.step_num = 0
        for step in range(self.step_num, self.train_config.steps):
@@ -85,6 +85,7 @@ class TrainConfig:
        self.batch_size: int = kwargs.get('batch_size', 1)
        self.dtype: str = kwargs.get('dtype', 'fp32')
        self.xformers = kwargs.get('xformers', False)
+        self.sdp = kwargs.get('sdp', False)
        self.train_unet = kwargs.get('train_unet', True)
        self.train_text_encoder = kwargs.get('train_text_encoder', True)
        self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
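As a usage note, a hedged sketch of constructing TrainConfig with the new flag; only the kwarg names shown in the hunk above are confirmed, the values are illustrative:

    from toolkit.config_modules import TrainConfig

    # Illustrative values only; the kwarg names match the diff above.
    train_config = TrainConfig(
        batch_size=1,
        dtype='fp32',
        xformers=False,
        sdp=True,  # enable torch scaled-dot-product attention backends
        train_unet=True,
        train_text_encoder=False,
    )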
@@ -4,7 +4,6 @@ import os
import re
import sys
from typing import List, Optional, Dict, Type, Union

import torch
from transformers import CLIPTextModel
@@ -46,11 +45,12 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
            dropout=None,
            rank_dropout=None,
            module_dropout=None,
+            network: 'LoRASpecialNetwork' = None,
            parent=None,
            **kwargs
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
-        super().__init__()
+        super().__init__(network=network)
        self.lora_name = lora_name
        self.scalar = torch.tensor(1.0)
@@ -150,7 +150,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        5. specify modules_dim and modules_alpha (for inference)
        """
        # call the parent of the parent we are replacing (LoRANetwork) init
-        super(LoRANetwork, self).__init__()
+        torch.nn.Module.__init__(self)

        self.lora_dim = lora_dim
        self.alpha = alpha
@@ -163,6 +163,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self._is_normalizing: bool = False
+        self.torch_multiplier = None
        # triggers the state updates
        self.multiplier = multiplier
        self.is_sdxl = is_sdxl
@@ -258,6 +259,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                dropout=dropout,
                rank_dropout=rank_dropout,
                module_dropout=module_dropout,
+                network=self,
                parent=module,
            )
            loras.append(lora)
@@ -29,6 +29,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
            lora_dim=4, alpha=1,
            dropout=0., rank_dropout=0., module_dropout=0.,
            use_cp=False,
+            network: 'LycorisSpecialNetwork' = None,
            parent=None,
            **kwargs,
    ):
@@ -36,7 +37,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
        # call super of super
        torch.nn.Module.__init__(self)
        # call super of
-        super().__init__(call_super_init=False)
+        super().__init__(call_super_init=False, network=network)
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False
@@ -170,6 +171,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
        if module_dropout is None:
            module_dropout = 0

+        self.torch_multiplier = None
+        # triggers a tensor update
        self.multiplier = multiplier
        self.lora_dim = lora_dim
@@ -229,6 +232,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                    self.lora_dim, self.alpha,
                    self.dropout, self.rank_dropout, self.module_dropout,
                    use_cp,
+                    network=self,
                    parent=module,
                    **kwargs
                )
@@ -240,6 +244,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                    self.lora_dim, self.alpha,
                    self.dropout, self.rank_dropout, self.module_dropout,
                    use_cp,
+                    network=self,
                    parent=module,
                    **kwargs
                )
@@ -249,6 +254,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                    self.conv_lora_dim, self.conv_alpha,
                    self.dropout, self.rank_dropout, self.module_dropout,
                    use_cp,
+                    network=self,
                    parent=module,
                    **kwargs
                )
@@ -271,6 +277,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                    self.dropout, self.rank_dropout, self.module_dropout,
                    use_cp,
                    parent=module,
+                    network=self,
                    **kwargs
                )
            elif module.__class__.__name__ == 'Conv2d':
@@ -281,6 +288,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                    self.lora_dim, self.alpha,
                    self.dropout, self.rank_dropout, self.module_dropout,
                    use_cp,
+                    network=self,
                    parent=module,
                    **kwargs
                )
@@ -290,6 +298,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                    self.conv_lora_dim, self.conv_alpha,
                    self.dropout, self.rank_dropout, self.module_dropout,
                    use_cp,
+                    network=self,
                    parent=module,
                    **kwargs
                )
@@ -4,10 +4,8 @@ from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any

import torch
from diffusers.utils import is_torch_version
from torch import nn
from torch.utils.checkpoint import checkpoint

import weakref
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
@@ -47,11 +45,13 @@ class ToolkitModuleMixin:
    def __init__(
            self: Module,
            *args,
+            network: Network,
            call_super_init: bool = True,
            **kwargs
    ):
        if call_super_init:
            super().__init__(*args, **kwargs)
+        self.network_ref: weakref.ref = weakref.ref(network)
        self.is_checkpointing = False
        self.is_normalizing = False
        self.normalize_scaler = 1.0
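Design note: each module holds only a weak reference to its owning network, so the network (which in turn holds the modules) and the modules do not keep each other alive in a reference cycle. A minimal illustration of the pattern, with hypothetical names:

    import weakref

    class Owner:
        def __init__(self):
            self.children = [Child(self)]

    class Child:
        def __init__(self, owner):
            # weakref.ref returns a callable; calling it yields the owner,
            # or None once the owner has been garbage collected
            self.owner_ref = weakref.ref(owner)

        def owner(self):
            return self.owner_ref()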
@@ -125,13 +125,13 @@ class ToolkitModuleMixin:
    # this may get an additional positional arg or not
    def forward(self: Module, x, *args, **kwargs):
        # diffusers added scale to resnet.. not sure what it does
        if self._multiplier is None:
            self.set_multiplier(0.0)
        if not self.network_ref().is_active:
            # network is not active, avoid doing anything
            return self.org_forward(x, *args, **kwargs)

        org_forwarded = self.org_forward(x, *args, **kwargs)
        lora_output = self._call_forward(x)
-        multiplier = self._multiplier.clone().detach()
+        multiplier = self.network_ref().torch_multiplier

        lora_output_batch_size = lora_output.size(0)
        multiplier_batch_size = multiplier.size(0)
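The rest of forward is elided by this hunk. For orientation, a hedged sketch of how a cached multiplier tensor is commonly combined with a LoRA output; the toolkit's exact batch-size handling is not shown in this diff, and the function name here is illustrative:

    import torch

    def combine(org_forwarded: torch.Tensor, lora_output: torch.Tensor,
                multiplier: torch.Tensor) -> torch.Tensor:
        # Sketch: expand a (k,) multiplier to the batch size, then broadcast
        # over the remaining dims before the weighted residual add.
        batch = lora_output.size(0)
        if multiplier.size(0) != batch:
            multiplier = multiplier.repeat(batch // multiplier.size(0))
        while multiplier.dim() < lora_output.dim():
            multiplier = multiplier.unsqueeze(-1)
        return org_forwarded + lora_output * multiplier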
@@ -328,35 +328,52 @@ class ToolkitNetworkMixin:
            extra_dict = None
        return extra_dict

+    def _update_torch_multiplier(self: Network):
+        # builds a tensor for fast usage in the forward pass of the network modules
+        # without having to set it in every single module every time it changes
+        multiplier = self._multiplier
+        # get first module
+        first_module = self.get_all_modules()[0]
+        device = first_module.lora_down.weight.device
+        dtype = first_module.lora_down.weight.dtype
+        with torch.no_grad():
+            tensor_multiplier = None
+            if isinstance(multiplier, int) or isinstance(multiplier, float):
+                tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
+            elif isinstance(multiplier, list):
+                tensor_list = []
+                for m in multiplier:
+                    if isinstance(m, int) or isinstance(m, float):
+                        tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
+                    elif isinstance(m, torch.Tensor):
+                        tensor_list.append(m.clone().detach().to(device, dtype=dtype))
+                tensor_multiplier = torch.cat(tensor_list)
+            elif isinstance(multiplier, torch.Tensor):
+                tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
+
+            self.torch_multiplier = tensor_multiplier.clone().detach()

    @property
    def multiplier(self) -> Union[float, List[float]]:
        return self._multiplier

    @multiplier.setter
    def multiplier(self, value: Union[float, List[float]]):
        # only update if the value has changed
        # it takes time to update all the multipliers, so we only do it if the value has changed
+        if self._multiplier == value:
+            return
        # if we are setting a single value but have a list, keep the list if every item is the same as value
        self._multiplier = value
        self._update_lora_multiplier()

    def _update_lora_multiplier(self: Network):
        if self.is_active:
            for lora in self.get_all_modules():
                lora.set_multiplier(self._multiplier)
        else:
            for lora in self.get_all_modules():
                lora.set_multiplier(0)
+        self._update_torch_multiplier()

    # called when the context manager is entered
    # ie: with network:
    def __enter__(self: Network):
        self.is_active = True
        self._update_lora_multiplier()

    def __exit__(self: Network, exc_type, exc_value, tb):
        self.is_active = False
        self._update_lora_multiplier()

    def force_to(self: Network, device, dtype):
        self.to(device, dtype)
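Activating the network is the context-manager entry shown above; combined with the early-out in the multiplier setter, this is where the speed improvement lands. A hedged usage sketch, assuming an already-built network and a diffusers pipeline variable:

    # Hypothetical usage: the network only contributes to forward passes
    # while the context manager is active. `network` and `pipeline` are
    # assumed to exist.
    network.multiplier = 0.8   # cheap: early-outs if unchanged
    with network:              # __enter__ sets is_active, refreshes multipliers
        images = pipeline(prompt="a photo of a cat").images
    # __exit__ zeroes the module multipliers again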
@@ -6,6 +6,7 @@ class ToolkitProgressBar(tqdm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.paused = False
+        self.last_time = self._time()

    def pause(self):
        if not self.paused:
@@ -15,7 +16,9 @@ class ToolkitProgressBar(tqdm):
    def unpause(self):
        if self.paused:
            self.paused = False
-            self.start_t += self._time() - self.last_time
+            cur_t = self._time()
+            self.start_t += cur_t - self.last_time
+            self.last_print_t = cur_t

    def update(self, *args, **kwargs):
        if not self.paused:
toolkit/sd_device_states_presets.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import torch
import copy

empty_preset = {
    'vae': {
        'training': False,
        'device': 'cpu',
    },
    'unet': {
        'training': False,
        'requires_grad': False,
        'device': 'cpu',
    },
    'text_encoder': {
        'training': False,
        'requires_grad': False,
        'device': 'cpu',
    }
}


def get_train_sd_device_state_preset(
        device: torch.DeviceObjType,
        train_unet: bool = False,
        train_text_encoder: bool = False,
        cached_latents: bool = False,
        train_lora: bool = False,
        train_embedding: bool = False,
):
    preset = copy.deepcopy(empty_preset)
    if not cached_latents:
        preset['vae']['device'] = device

    if train_unet:
        preset['unet']['training'] = True
        preset['unet']['requires_grad'] = True
        preset['unet']['device'] = device
    else:
        preset['unet']['device'] = device

    if train_text_encoder:
        preset['text_encoder']['training'] = True
        preset['text_encoder']['requires_grad'] = True
        preset['text_encoder']['device'] = device
    else:
        preset['text_encoder']['device'] = device

    if train_embedding:
        preset['text_encoder']['training'] = True
        preset['text_encoder']['requires_grad'] = True
        preset['unet']['training'] = True

    if train_lora:
        preset['text_encoder']['requires_grad'] = False
        preset['unet']['requires_grad'] = False

    return preset
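For reference, a hedged sketch of building one of these presets and applying it, mirroring how BaseSDTrainProcess uses it above; the argument values are illustrative:

    import torch
    from toolkit.sd_device_states_presets import get_train_sd_device_state_preset

    # Hypothetical values: a LoRA run with cached latents on the first GPU.
    preset = get_train_sd_device_state_preset(
        device=torch.device('cuda:0'),
        train_unet=True,
        train_text_encoder=False,
        cached_latents=True,   # vae stays on cpu
        train_lora=True,       # base weights keep requires_grad False
    )
    # sd is assumed to be a toolkit StableDiffusion instance
    sd.set_device_state(preset)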
@@ -24,7 +24,6 @@ from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
-from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
    StableDiffusionKDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
@@ -48,7 +47,7 @@ DO_NOT_TRAIN_WEIGHTS = [
    "unet_time_embedding.linear_2.weight",
]

-DeviceStatePreset = Literal['cache_latents']
+DeviceStatePreset = Literal['cache_latents', 'generate']


class BlankNetwork:
@@ -111,7 +110,7 @@ class StableDiffusion:
        self.unet: Union[None, 'UNet2DConditionModel']
        self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
        self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
-        self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler'] = noise_scheduler
+        self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler

        # sdxl stuff
        self.logit_scale = None
@@ -247,7 +246,7 @@ class StableDiffusion:
        # pipe.unet = prepare_unet_for_training(pipe.unet)

        self.unet = pipe.unet
-        self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
+        self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
        self.vae.eval()
        self.vae.requires_grad_(False)
        self.unet.to(self.device_torch, dtype=dtype)
@@ -275,26 +274,12 @@ class StableDiffusion:
            network.is_normalizing = False

+        self.save_device_state()
+        self.set_device_state_preset('generate')

        # save current seed state for training
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

-        # handle sdxl text encoder
-        if isinstance(self.text_encoder, list):
-            for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
-                encoder.to(self.device_torch)
-                encoder.eval()
-        else:
-            self.text_encoder.to(self.device_torch)
-            self.text_encoder.eval()
-
-        self.vae.to(self.device_torch)
-        self.vae.eval()
-        self.unet.to(self.device_torch)
-        self.unet.eval()
-        flush()

        noise_scheduler = self.noise_scheduler
        if sampler is not None:
            if sampler.startswith("sample_"):  # sample_dpmpp_2m
@@ -346,6 +331,7 @@ class StableDiffusion:
        start_multiplier = self.network.multiplier

+        pipeline.to(self.device_torch)

        with network:
            with torch.no_grad():
                if self.network is not None:
@@ -876,6 +862,7 @@ class StableDiffusion:
            'unet': {
                'training': self.unet.training,
                'device': self.unet.device,
+                'requires_grad': self.unet.conv_in.weight.requires_grad,
            },
        }
        if isinstance(self.text_encoder, list):
@@ -884,11 +871,14 @@ class StableDiffusion:
                self.device_state['text_encoder'].append({
                    'training': encoder.training,
                    'device': encoder.device,
+                    # todo there has to be a better way to do this
+                    'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad
                })
        else:
            self.device_state['text_encoder'] = {
                'training': self.text_encoder.training,
                'device': self.text_encoder.device,
+                'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
            }

    def restore_device_state(self):
@@ -910,19 +900,33 @@ class StableDiffusion:
        else:
            self.unet.eval()
        self.unet.to(state['unet']['device'])
+        if state['unet']['requires_grad']:
+            self.unet.requires_grad_(True)
+        else:
+            self.unet.requires_grad_(False)
        if isinstance(self.text_encoder, list):
            for i, encoder in enumerate(self.text_encoder):
-                if state['text_encoder'][i]['training']:
-                    encoder.train()
-                else:
-                    encoder.eval()
-                encoder.to(state['text_encoder'][i]['device'])
+                if isinstance(state['text_encoder'], list):
+                    if state['text_encoder'][i]['training']:
+                        encoder.train()
+                    else:
+                        encoder.eval()
+                    encoder.to(state['text_encoder'][i]['device'])
+                    encoder.requires_grad_(state['text_encoder'][i]['requires_grad'])
+                else:
+                    if state['text_encoder']['training']:
+                        encoder.train()
+                    else:
+                        encoder.eval()
+                    encoder.to(state['text_encoder']['device'])
+                    encoder.requires_grad_(state['text_encoder']['requires_grad'])
        else:
            if state['text_encoder']['training']:
                self.text_encoder.train()
            else:
                self.text_encoder.eval()
            self.text_encoder.to(state['text_encoder']['device'])
+            self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])
        flush()

    def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -935,18 +939,22 @@ class StableDiffusion:
        training_modules = []
        if device_state_preset in ['cache_latents']:
            active_modules = ['vae']
+        if device_state_preset in ['generate']:
+            active_modules = ['vae', 'unet', 'text_encoder']

        state = {}
        # vae
        state['vae'] = {
            'training': 'vae' in training_modules,
            'device': self.device_torch if 'vae' in active_modules else 'cpu',
+            'requires_grad': 'vae' in training_modules,
        }

        # unet
        state['unet'] = {
            'training': 'unet' in training_modules,
            'device': self.device_torch if 'unet' in active_modules else 'cpu',
+            'requires_grad': 'unet' in training_modules,
        }

        # text encoder
@@ -956,11 +964,13 @@ class StableDiffusion:
                state['text_encoder'].append({
                    'training': 'text_encoder' in training_modules,
                    'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                    'requires_grad': 'text_encoder' in training_modules,
                })
        else:
            state['text_encoder'] = {
                'training': 'text_encoder' in training_modules,
                'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                'requires_grad': 'text_encoder' in training_modules,
            }

        self.set_device_state(state)