mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Implemented the device placement preset system in more places. Vastly improved speed when setting the network multiplier and activating the network. Fixed timing issues on the progress bar.
This commit is contained in:
@@ -4,10 +4,8 @@ from collections import OrderedDict
|
||||
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
|
||||
|
||||
import torch
|
||||
from diffusers.utils import is_torch_version
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
import weakref
|
||||
from toolkit.metadata import add_model_hash_to_meta
|
||||
from toolkit.paths import KEYMAPS_ROOT
|
||||
|
||||
@@ -47,11 +45,13 @@ class ToolkitModuleMixin:
|
||||
def __init__(
|
||||
self: Module,
|
||||
*args,
|
||||
network: Network,
|
||||
call_super_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
if call_super_init:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.network_ref: weakref.ref = weakref.ref(network)
|
||||
self.is_checkpointing = False
|
||||
self.is_normalizing = False
|
||||
self.normalize_scaler = 1.0
|
||||
@@ -125,13 +125,13 @@ class ToolkitModuleMixin:
|
||||
# this may get an additional positional arg or not
|
||||
|
||||
def forward(self: Module, x, *args, **kwargs):
|
||||
# diffusers added scale to resnet.. not sure what it does
|
||||
if self._multiplier is None:
|
||||
self.set_multiplier(0.0)
|
||||
if not self.network_ref().is_active:
|
||||
# network is not active, avoid doing anything
|
||||
return self.org_forward(x, *args, **kwargs)
|
||||
|
||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||
lora_output = self._call_forward(x)
|
||||
multiplier = self._multiplier.clone().detach()
|
||||
multiplier = self.network_ref().torch_multiplier
|
||||
|
||||
lora_output_batch_size = lora_output.size(0)
|
||||
multiplier_batch_size = multiplier.size(0)
|
||||
@@ -328,35 +328,52 @@ class ToolkitNetworkMixin:
|
||||
extra_dict = None
|
||||
return extra_dict
|
||||
|
||||
def _update_torch_multiplier(self: Network):
    """Cache ``self._multiplier`` as a tensor in ``self.torch_multiplier``.

    Builds a tensor for fast usage in the forward pass of the network
    modules, without having to set it in every single module every time it
    changes. The tensor is moved to the device and dtype of the first
    module's ``lora_down`` weight; when the network has no modules it is
    left on its own device/dtype.

    Accepted multiplier values:
        * ``int`` / ``float`` -> 1-element tensor
        * ``torch.Tensor``    -> detached copy (0-dim is promoted to 1-D)
        * ``list`` of the above -> entries concatenated in order

    Raises:
        TypeError: if the multiplier, or any list entry, is not a number or
            a tensor (previously such list entries were silently dropped
            and an unsupported top-level value failed on ``None.clone()``).
    """
    modules = self.get_all_modules()
    placement = {}
    if modules:
        reference = modules[0].lora_down.weight
        placement = {"device": reference.device, "dtype": reference.dtype}

    def to_tensor(value):
        if isinstance(value, torch.Tensor):
            # clone so the cached tensor never aliases the caller's tensor
            tensor = value.detach().clone()
        elif isinstance(value, (int, float)):
            tensor = torch.tensor((value,))
        else:
            raise TypeError(
                f"unsupported multiplier type: {type(value).__name__}"
            )
        if tensor.dim() == 0:
            # a 0-dim tensor has no size(0) and cannot be concatenated
            tensor = tensor.reshape(1)
        return tensor.to(**placement) if placement else tensor

    multiplier = self._multiplier
    with torch.no_grad():
        if isinstance(multiplier, list):
            cached = torch.cat([to_tensor(item) for item in multiplier])
        else:
            cached = to_tensor(multiplier)

    self.torch_multiplier = cached
|
||||
|
||||
|
||||
@property
def multiplier(self) -> Union[float, List[float]]:
    """Return the multiplier exactly as it was last assigned."""
    return self._multiplier

@multiplier.setter
def multiplier(self, value: Union[float, List[float]]):
    """Store a new multiplier and propagate it, skipping no-op assignments.

    It takes time to update all the multipliers, so the update only runs
    when the value has actually changed. Tensor values are compared as a
    whole: a plain ``==`` yields an element-wise tensor whose truth value
    is ambiguous (RuntimeError) for more than one element.
    """
    current = self._multiplier
    try:
        unchanged = bool(current == value)
    except (RuntimeError, ValueError):
        # element-wise result with several elements, or shapes/devices that
        # cannot be compared directly: equal only if both are tensors with
        # the same shape and device and every element matches
        unchanged = (
            isinstance(current, torch.Tensor)
            and isinstance(value, torch.Tensor)
            and current.shape == value.shape
            and current.device == value.device
            and bool((current == value).all())
        )
    if unchanged:
        return
    # NOTE(review): an earlier comment here described keeping an existing
    # list when every item equals a new scalar value; the code does not do
    # that - the stored value is always replaced.
    self._multiplier = value
    self._update_lora_multiplier()
|
||||
|
||||
def _update_lora_multiplier(self: Network):
    """Push the effective multiplier to every module, then refresh the cache.

    While the network is active each module receives ``self._multiplier``;
    while it is inactive each module receives ``0``. Afterwards the cached
    tensor is rebuilt via ``_update_torch_multiplier``.
    """
    effective = self._multiplier if self.is_active else 0
    for module in self.get_all_modules():
        module.set_multiplier(effective)
    self._update_torch_multiplier()
|
||||
|
||||
# called when the context manager is entered
|
||||
# ie: with network:
|
||||
def __enter__(self: Network):
    """Activate the network for the duration of a ``with`` block.

    Sets ``is_active`` and calls ``_update_lora_multiplier``.

    Returns:
        The network itself, so ``with network as n:`` binds the network
        instead of ``None``.
    """
    self.is_active = True
    self._update_lora_multiplier()
    return self
|
||||
|
||||
def __exit__(self: Network, exc_type, exc_value, tb):
    """Deactivate the network when the ``with`` block ends.

    Clears ``is_active`` and calls ``_update_lora_multiplier``. The
    exception arguments are ignored and ``None`` is returned, so an
    exception raised inside the block is not suppressed.
    """
    self.is_active = False
    self._update_lora_multiplier()
|
||||
|
||||
def force_to(self: Network, device, dtype):
|
||||
self.to(device, dtype)
|
||||
|
||||
Reference in New Issue
Block a user