implemented device placement preset system more places. Vastly improved speed on setting network multiplier and activating network. Fixed timing issues on progress bar

This commit is contained in:
Jaret Burkett
2023-09-14 08:31:54 -06:00
parent 4e945917df
commit 569d7464d5
9 changed files with 173 additions and 91 deletions

View File

@@ -4,7 +4,6 @@ import os
import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
@@ -46,11 +45,12 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
network: 'LoRASpecialNetwork' = None,
parent=None,
**kwargs
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
super().__init__(network=network)
self.lora_name = lora_name
self.scalar = torch.tensor(1.0)
@@ -150,7 +150,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
5. modules_dimとmodules_alphaを指定 (推論用)
"""
# call the parent of the parent we are replacing (LoRANetwork) init
super(LoRANetwork, self).__init__()
torch.nn.Module.__init__(self)
self.lora_dim = lora_dim
self.alpha = alpha
@@ -163,6 +163,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
self.torch_multiplier = None
# triggers the state updates
self.multiplier = multiplier
self.is_sdxl = is_sdxl
@@ -258,6 +259,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
network=self,
parent=module,
)
loras.append(lora)