mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Tons of bug fixes and improvements to special training. Fixed slider training.
This commit is contained in:
@@ -13,6 +13,7 @@ from toolkit.config_modules import NetworkConfig
|
||||
from toolkit.lorm import extract_conv, extract_linear, count_parameters
|
||||
from toolkit.metadata import add_model_hash_to_meta
|
||||
from toolkit.paths import KEYMAPS_ROOT
|
||||
from toolkit.saving import get_lora_keymap_from_model_keymap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
|
||||
@@ -338,6 +339,7 @@ class ToolkitNetworkMixin:
|
||||
train_unet: Optional[bool] = True,
|
||||
is_sdxl=False,
|
||||
is_v2=False,
|
||||
is_ssd=False,
|
||||
network_config: Optional[NetworkConfig] = None,
|
||||
is_lorm=False,
|
||||
**kwargs
|
||||
@@ -348,6 +350,7 @@ class ToolkitNetworkMixin:
|
||||
self._multiplier: float = 1.0
|
||||
self.is_active: bool = False
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_ssd = is_ssd
|
||||
self.is_v2 = is_v2
|
||||
self.is_merged_in = False
|
||||
self.is_lorm = is_lorm
|
||||
@@ -357,14 +360,25 @@ class ToolkitNetworkMixin:
|
||||
self.can_merge_in = not is_lorm
|
||||
|
||||
def get_keymap(self: Network):
|
||||
if self.is_sdxl:
|
||||
use_weight_mapping = False
|
||||
|
||||
if self.is_ssd:
|
||||
keymap_tail = 'ssd'
|
||||
use_weight_mapping = True
|
||||
elif self.is_sdxl:
|
||||
keymap_tail = 'sdxl'
|
||||
elif self.is_v2:
|
||||
keymap_tail = 'sd2'
|
||||
else:
|
||||
keymap_tail = 'sd1'
|
||||
# todo double check this
|
||||
use_weight_mapping = True
|
||||
|
||||
# load keymap
|
||||
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
|
||||
if use_weight_mapping:
|
||||
keymap_name = f"stable_diffusion_{keymap_tail}.json"
|
||||
|
||||
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
|
||||
|
||||
keymap = None
|
||||
@@ -373,6 +387,10 @@ class ToolkitNetworkMixin:
|
||||
with open(keymap_path, 'r') as f:
|
||||
keymap = json.load(f)['ldm_diffusers_keymap']
|
||||
|
||||
if use_weight_mapping and keymap is not None:
|
||||
# get keymap from weights
|
||||
keymap = get_lora_keymap_from_model_keymap(keymap)
|
||||
|
||||
return keymap
|
||||
|
||||
def save_weights(
|
||||
|
||||
Reference in New Issue
Block a user