Tons of bug fixes and improvements to special training. Fixed slider training.

This commit is contained in:
Jaret Burkett
2023-12-09 16:38:10 -07:00
parent eaec2f5a52
commit eaa0fb6253
9 changed files with 639 additions and 74 deletions

View File

@@ -13,6 +13,7 @@ from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
from toolkit.saving import get_lora_keymap_from_model_keymap
if TYPE_CHECKING:
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
@@ -338,6 +339,7 @@ class ToolkitNetworkMixin:
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
is_ssd=False,
network_config: Optional[NetworkConfig] = None,
is_lorm=False,
**kwargs
@@ -348,6 +350,7 @@ class ToolkitNetworkMixin:
self._multiplier: float = 1.0
self.is_active: bool = False
self.is_sdxl = is_sdxl
self.is_ssd = is_ssd
self.is_v2 = is_v2
self.is_merged_in = False
self.is_lorm = is_lorm
@@ -357,14 +360,25 @@ class ToolkitNetworkMixin:
self.can_merge_in = not is_lorm
def get_keymap(self: Network):
if self.is_sdxl:
use_weight_mapping = False
if self.is_ssd:
keymap_tail = 'ssd'
use_weight_mapping = True
elif self.is_sdxl:
keymap_tail = 'sdxl'
elif self.is_v2:
keymap_tail = 'sd2'
else:
keymap_tail = 'sd1'
# todo double check this
use_weight_mapping = True
# load keymap
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
if use_weight_mapping:
keymap_name = f"stable_diffusion_{keymap_tail}.json"
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
keymap = None
@@ -373,6 +387,10 @@ class ToolkitNetworkMixin:
with open(keymap_path, 'r') as f:
keymap = json.load(f)['ldm_diffusers_keymap']
if use_weight_mapping and keymap is not None:
# get keymap from weights
keymap = get_lora_keymap_from_model_keymap(keymap)
return keymap
def save_weights(