Implemented device placement preset system in more places. Vastly improved speed when setting the network multiplier and activating the network. Fixed timing issues on the progress bar.

Jaret Burkett
2023-09-14 08:31:54 -06:00
parent 4e945917df
commit 569d7464d5
9 changed files with 173 additions and 91 deletions
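Most of the speed win here comes from moving the multiplier out of the individual LoRA modules and caching it once on the network as a tensor (torch_multiplier), guarded by a change check in the setter. A minimal sketch of that pattern, with simplified stand-in classes that only mirror the names in the diff, not the repo's actual classes:

import weakref
import torch

class SketchNetwork:
    def __init__(self):
        self.is_active = False
        self._multiplier = 1.0
        self.torch_multiplier = None  # cached once; modules just read it

    def set_multiplier(self, value: float, device="cpu", dtype=torch.float32):
        if self._multiplier == value:
            return  # skip the rebuild entirely when nothing changed
        self._multiplier = value
        with torch.no_grad():
            self.torch_multiplier = torch.tensor((value,), device=device, dtype=dtype)

class SketchModule:
    def __init__(self, network: "SketchNetwork"):
        # weakref so module and network do not keep each other alive
        self.network_ref = weakref.ref(network)

    def scale(self, lora_output: torch.Tensor) -> torch.Tensor:
        # one tensor read per forward instead of a per-module copy
        return lora_output * self.network_ref().torch_multiplier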


@@ -4,10 +4,8 @@ from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
import torch
from diffusers.utils import is_torch_version
from torch import nn
from torch.utils.checkpoint import checkpoint
import weakref
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
@@ -47,11 +45,13 @@ class ToolkitModuleMixin:
    def __init__(
            self: Module,
            *args,
            network: Network,
            call_super_init: bool = True,
            **kwargs
    ):
        if call_super_init:
            super().__init__(*args, **kwargs)
        self.network_ref: weakref.ref = weakref.ref(network)
        self.is_checkpointing = False
        self.is_normalizing = False
        self.normalize_scaler = 1.0
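The mixin keeps only a weak reference to the owning network, so every access in forward goes through self.network_ref(); calling a weakref returns the referent, or None once it has been collected. A short illustration of the dereference pattern (the None check is an assumption, not in the diff):

import weakref

class Owner:
    is_active = True

owner = Owner()
ref = weakref.ref(owner)   # same pattern as self.network_ref
net = ref()                # call the ref to get the object back
assert net is owner and net.is_active
del owner, net
assert ref() is None       # after collection the ref yields None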
@@ -125,13 +125,13 @@ class ToolkitModuleMixin:
    # this may get an additional positional arg or not
    def forward(self: Module, x, *args, **kwargs):
        # diffusers added scale to resnet.. not sure what it does
        if self._multiplier is None:
            self.set_multiplier(0.0)
        if not self.network_ref().is_active:
            # network is not active, avoid doing anything
            return self.org_forward(x, *args, **kwargs)
        org_forwarded = self.org_forward(x, *args, **kwargs)
        lora_output = self._call_forward(x)
        multiplier = self._multiplier.clone().detach()
        multiplier = self.network_ref().torch_multiplier
        lora_output_batch_size = lora_output.size(0)
        multiplier_batch_size = multiplier.size(0)
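The hunk ends right after the two batch sizes are read; presumably the multiplier is repeated when the output batch is a multiple of it (as when classifier-free guidance doubles the batch) and then broadcast over the feature dimensions. A hedged sketch of that matching step, where the repeat rule is an assumption since the diff cuts off here:

import torch

def match_multiplier(lora_output: torch.Tensor, multiplier: torch.Tensor) -> torch.Tensor:
    out_bs, mult_bs = lora_output.size(0), multiplier.size(0)
    if mult_bs != out_bs and out_bs % mult_bs == 0:
        # e.g. cond + uncond passes run in one doubled batch
        multiplier = multiplier.repeat_interleave(out_bs // mult_bs)
    # reshape to (batch, 1, 1, ...) so it broadcasts over feature dims
    return lora_output * multiplier.view(-1, *([1] * (lora_output.dim() - 1)))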
@@ -328,35 +328,52 @@ class ToolkitNetworkMixin:
            extra_dict = None
        return extra_dict

    def _update_torch_multiplier(self: Network):
        # builds a tensor for fast usage in the forward pass of the network modules
        # without having to set it in every single module every time it changes
        multiplier = self._multiplier
        # get first module
        first_module = self.get_all_modules()[0]
        device = first_module.lora_down.weight.device
        dtype = first_module.lora_down.weight.dtype
        with torch.no_grad():
            tensor_multiplier = None
            if isinstance(multiplier, int) or isinstance(multiplier, float):
                tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
            elif isinstance(multiplier, list):
                tensor_list = []
                for m in multiplier:
                    if isinstance(m, int) or isinstance(m, float):
                        tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
                    elif isinstance(m, torch.Tensor):
                        tensor_list.append(m.clone().detach().to(device, dtype=dtype))
                tensor_multiplier = torch.cat(tensor_list)
            elif isinstance(multiplier, torch.Tensor):
                tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
            self.torch_multiplier = tensor_multiplier.clone().detach()

    @property
    def multiplier(self) -> Union[float, List[float]]:
        return self._multiplier

    @multiplier.setter
    def multiplier(self, value: Union[float, List[float]]):
        # only update if the value has changed
        # it takes time to update all the multipliers, so we only do it if the value has changed
        if self._multiplier == value:
            return
        # if we are setting a single value but have a list, keep the list if every item is the same as value
        self._multiplier = value
        self._update_lora_multiplier()

    def _update_lora_multiplier(self: Network):
        if self.is_active:
            for lora in self.get_all_modules():
                lora.set_multiplier(self._multiplier)
        else:
            for lora in self.get_all_modules():
                lora.set_multiplier(0)
        self._update_torch_multiplier()

    # called when the context manager is entered
    # ie: with network:
    def __enter__(self: Network):
        self.is_active = True
        self._update_lora_multiplier()

    def __exit__(self: Network, exc_type, exc_value, tb):
        self.is_active = False
        self._update_lora_multiplier()

    def force_to(self: Network, device, dtype):
        self.to(device, dtype)
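A hypothetical usage of the context-manager API above; network and unet stand in for real objects from a training or inference loop:

network.multiplier = 0.8                     # setter returns early if the value is unchanged
with network:                                # __enter__ sets is_active and syncs every module
    pred = unet(latents, timestep, embeds)   # LoRA applied at strength 0.8
# __exit__ ran: modules are back at multiplier 0 and fall through to org_forward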