Added Model rescale and prepared a release upgrade
@@ -99,5 +99,5 @@ class SliderConfig:
         anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
         self.anchors: List[SliderConfigAnchors] = anchors
         self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
-        self.prompt_file: str = kwargs.get('prompt_file', '')
-        self.prompt_tensors: str = kwargs.get('prompt_tensors', '')
+        self.prompt_file: str = kwargs.get('prompt_file', None)
+        self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
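The prompt defaults move from empty strings to None, so downstream code can tell "not configured" apart from "configured but empty" with an explicit is None check. A minimal sketch of that distinction, using a hypothetical load_prompts consumer that is not part of this commit:

    # hypothetical consumer, illustrating why a None default is easier to test for
    from typing import Optional

    def load_prompts(prompt_file: Optional[str]) -> list:
        if prompt_file is None:
            return []  # nothing was configured
        with open(prompt_file, 'r') as f:
            return [line.strip() for line in f if line.strip()]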
@@ -13,6 +13,9 @@ def get_job(config_path, name=None):
     if job == 'train':
         from jobs import TrainJob
         return TrainJob(config)
+    if job == 'mod':
+        from jobs import ModJob
+        return ModJob(config)
 
     # elif job == 'train':
     #     from jobs import TrainJob
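The job factory gains a 'mod' branch, so a config whose job key is "mod" now resolves to the new ModJob. A hedged usage sketch (the config path and name below are made up):

    # hypothetical invocation; the 'job' key inside the config file picks the class
    job = get_job('config/my_mod_config.yaml', name='my_job')
    job.run()  # assumes the Job classes expose a run() entry point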
@@ -6,12 +6,14 @@
 import os
 import math
 from typing import Optional, List, Type, Set, Literal
 from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 from diffusers import UNet2DConditionModel
 from safetensors.torch import save_file
 
+from toolkit.metadata import add_model_hash_to_meta
+
 UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
     "Transformer2DModel",  # apparently this is the right one?  # attn1, 2
@@ -31,7 +33,7 @@ TRAINING_METHODS = Literal[
     "innoxattn",  # train all layers except self attention layers
     "selfattn",  # ESD-u, train only self attention layers
     "xattn",  # ESD-x, train only x attention layers
-    "full",  # train all layers
+    "full",  # train all layers
     # "notime",
     # "xlayer",
     # "outxattn",
@@ -48,12 +50,12 @@ class LoRAModule(nn.Module):
     """
 
     def __init__(
-        self,
-        lora_name,
-        org_module: nn.Module,
-        multiplier=1.0,
-        lora_dim=4,
-        alpha=1,
+        self,
+        lora_name,
+        org_module: nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
@@ -102,19 +104,19 @@ class LoRAModule(nn.Module):
 
     def forward(self, x):
         return (
-            self.org_forward(x)
-            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+            self.org_forward(x)
+            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
         )
 
 
 class LoRANetwork(nn.Module):
     def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        rank: int = 4,
-        multiplier: float = 1.0,
-        alpha: float = 1.0,
-        train_method: TRAINING_METHODS = "full",
+        self,
+        unet: UNet2DConditionModel,
+        rank: int = 4,
+        multiplier: float = 1.0,
+        alpha: float = 1.0,
+        train_method: TRAINING_METHODS = "full",
     ) -> None:
         super().__init__()
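The forward pass here is the standard LoRA composition: the frozen original module's output plus a low-rank correction scaled by multiplier and scale (scale being alpha / lora_dim; per the docstring, alpha falls back to the rank, giving scale 1). A self-contained sketch of the same arithmetic, not taken from this commit:

    # standalone illustration of y = W0(x) + up(down(x)) * multiplier * (alpha / rank)
    import torch
    import torch.nn as nn

    rank, alpha, multiplier = 4, 1.0, 1.0
    org = nn.Linear(320, 320)                # stands in for the frozen original module
    down = nn.Linear(320, rank, bias=False)  # lora_down: project into the low-rank space
    up = nn.Linear(rank, 320, bias=False)    # lora_up: project back out
    nn.init.zeros_(up.weight)                # zero init makes the correction start as a no-op

    x = torch.randn(2, 320)
    scale = alpha / rank
    y = org(x) + up(down(x)) * multiplier * scale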
@@ -140,7 +142,7 @@ class LoRANetwork(nn.Module):
         lora_names = set()
         for lora in self.unet_loras:
             assert (
-                lora.lora_name not in lora_names
+                lora.lora_name not in lora_names
             ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
             lora_names.add(lora.lora_name)
@@ -157,13 +159,13 @@ class LoRANetwork(nn.Module):
         torch.cuda.empty_cache()
 
     def create_modules(
-        self,
-        prefix: str,
-        root_module: nn.Module,
-        target_replace_modules: List[str],
-        rank: int,
-        multiplier: float,
-        train_method: TRAINING_METHODS,
+        self,
+        prefix: str,
+        root_module: nn.Module,
+        target_replace_modules: List[str],
+        rank: int,
+        multiplier: float,
+        train_method: TRAINING_METHODS,
     ) -> list:
         loras = []
@@ -212,6 +214,8 @@ class LoRANetwork(nn.Module):
 
     def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
         state_dict = self.state_dict()
+        if metadata is None:
+            metadata = OrderedDict()
 
         if dtype is not None:
             for key in list(state_dict.keys()):
@@ -221,9 +225,10 @@ class LoRANetwork(nn.Module):
 
         for key in list(state_dict.keys()):
             if not key.startswith("lora"):
-                # exclude anything that is not lora
+                # remove any not lora
                 del state_dict[key]
 
+        metadata = add_model_hash_to_meta(state_dict, metadata)
         if os.path.splitext(file)[1] == ".safetensors":
             save_file(state_dict, file, metadata)
         else:
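save_weights now guarantees a metadata dict and stamps it with the addnet model hashes before writing. A hedged usage sketch (the construction arguments and file name are illustrative, not from this commit):

    # hypothetical call; unet is an already-loaded UNet2DConditionModel
    network = LoRANetwork(unet, rank=4, multiplier=1.0, alpha=1.0)
    network.save_weights('my_lora.safetensors', dtype=torch.float16)
    # the .safetensors metadata now includes sshs_model_hash and sshs_legacy_hash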
@@ -1,18 +1,23 @@
 import json
 from collections import OrderedDict
+from io import BytesIO
 
+import safetensors
 from safetensors import safe_open
 
 from info import software_meta
+from toolkit.train_tools import addnet_hash_legacy
+from toolkit.train_tools import addnet_hash_safetensors
 
 
-def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
+def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict:
     # stringify the meta and reparse OrderedDict to replace [name] with name
     meta_string = json.dumps(meta)
     if name is not None:
         meta_string = meta_string.replace("[name]", name)
     save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
-    save_meta["software"] = software_meta
+    if add_software_info:
+        save_meta["software"] = software_meta
     # safetensors can only be one level deep
     for key, value in save_meta.items():
         # if not float, int, bool, or str, convert to json string
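The new add_software_info flag makes the "software" stamp optional, e.g. when round-tripping metadata that should stay untouched. A small sketch of the [name] substitution and the flag (the metadata values are made up):

    # hypothetical input metadata
    meta = OrderedDict(ss_output_name='[name]', ss_base_model_version='sd_1.5')
    save_meta = get_meta_for_safetensors(meta, name='my_lora', add_software_info=False)
    # save_meta['ss_output_name'] == 'my_lora' and no 'software' key was injected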
@@ -21,6 +26,46 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
     return save_meta
 
 
+def add_model_hash_to_meta(state_dict, meta: OrderedDict) -> OrderedDict:
+    """Precalculate the model hashes needed by sd-webui-additional-networks to
+    save time on indexing the model later."""
+
+    # Because writing user metadata to the file can change the result of
+    # sd_models.model_hash(), only retain the training metadata for purposes of
+    # calculating the hash, as they are meant to be immutable
+    metadata = {k: v for k, v in meta.items() if k.startswith("ss_")}
+
+    bytes = safetensors.torch.save(state_dict, metadata)
+    b = BytesIO(bytes)
+
+    model_hash = addnet_hash_safetensors(b)
+    legacy_hash = addnet_hash_legacy(b)
+    meta["sshs_model_hash"] = model_hash
+    meta["sshs_legacy_hash"] = legacy_hash
+    return meta
+
+
+def add_base_model_info_to_meta(
+        meta: OrderedDict,
+        base_model: str = None,
+        is_v1: bool = False,
+        is_v2: bool = False,
+        is_xl: bool = False,
+) -> OrderedDict:
+    if base_model is not None:
+        meta['ss_base_model'] = base_model
+    elif is_v2:
+        meta['ss_v2'] = True
+        meta['ss_base_model_version'] = 'sd_2.1'
+
+    elif is_xl:
+        meta['ss_base_model_version'] = 'sdxl_1.0'
+    else:
+        # default to v1.5
+        meta['ss_base_model_version'] = 'sd_1.5'
+    return meta
+
+
 def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
     parsed_meta = OrderedDict()
     for key, value in meta.items():
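Taken together, the two new helpers build release-ready metadata: add_base_model_info_to_meta records which base model the weights target, and add_model_hash_to_meta hashes only the immutable ss_-prefixed entries so later user edits don't change the hash. A hedged flow (state_dict assumed to be an already-built LoRA state dict):

    # hypothetical flow combining the helpers added in this commit
    meta = OrderedDict()
    meta = add_base_model_info_to_meta(meta, is_xl=True)  # sets ss_base_model_version='sdxl_1.0'
    meta = add_model_hash_to_meta(state_dict, meta)       # adds sshs_model_hash / sshs_legacy_hash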
@@ -54,6 +54,8 @@ def get_optimizer(
     elif lower_type == 'lion':
         from lion_pytorch import Lion
         return Lion(params, lr=learning_rate, **optimizer_params)
+    elif lower_type == 'adagrad':
+        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
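The new branch is a thin wrapper over the stock PyTorch optimizer, so selecting 'adagrad' in a config is equivalent to this direct call (a sketch; model stands in for whatever module the job trains):

    # what the new branch boils down to
    import torch
    optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2)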
@@ -1,4 +1,5 @@
 import argparse
+import hashlib
 import json
 import os
 import time
@@ -399,3 +400,29 @@ def concat_prompt_embeddings(
         [unconditional.pooled_embeds, conditional.pooled_embeds]
     ).repeat_interleave(n_imgs, dim=0)
     return PromptEmbeds([text_embeds, pooled_embeds])
+
+
+def addnet_hash_safetensors(b):
+    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+    hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
+
+    b.seek(0)
+    header = b.read(8)
+    n = int.from_bytes(header, "little")
+
+    offset = n + 8
+    b.seek(offset)
+    for chunk in iter(lambda: b.read(blksize), b""):
+        hash_sha256.update(chunk)
+
+    return hash_sha256.hexdigest()
+
+
+def addnet_hash_legacy(b):
+    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+    m = hashlib.sha256()
+
+    b.seek(0x100000)
+    m.update(b.read(0x10000))
+    return m.hexdigest()[0:8]
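addnet_hash_safetensors exploits the .safetensors layout: the first 8 bytes are a little-endian length n of the JSON header, so hashing from offset n + 8 covers only the tensor data and ignores the mutable metadata header. A runnable sketch of that header skip on an in-memory file (the tensor contents are made up):

    # demonstrate the offset arithmetic the new hash function relies on
    import hashlib
    from io import BytesIO
    import torch
    import safetensors.torch

    blob = BytesIO(safetensors.torch.save({'w': torch.zeros(4)}))
    n = int.from_bytes(blob.getbuffer()[:8], 'little')  # JSON header length
    blob.seek(8 + n)                                    # skip length prefix + header
    digest = hashlib.sha256(blob.read()).hexdigest()    # hash covers tensor data only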