Added Model rescale and prepared a release upgrade

Jaret Burkett
2023-08-01 13:49:54 -06:00
parent 63cacf4362
commit 8b8d53888d
15 changed files with 388 additions and 64 deletions


@@ -6,12 +6,14 @@
 import os
 import math
 from typing import Optional, List, Type, Set, Literal
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 from diffusers import UNet2DConditionModel
 from safetensors.torch import save_file
+from toolkit.metadata import add_model_hash_to_meta

 UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
     "Transformer2DModel",  # apparently this is the one?  # attn1, 2
 ]
@@ -31,7 +33,7 @@ TRAINING_METHODS = Literal[
     "innoxattn",  # train all layers except self attention layers
     "selfattn",  # ESD-u, train only self attention layers
     "xattn",  # ESD-x, train only x attention layers
-    "full",  # train all layers
+    "full",  # train all layers
     # "notime",
     # "xlayer",
     # "outxattn",
@@ -48,12 +50,12 @@ class LoRAModule(nn.Module):
     """

     def __init__(
-        self,
-        lora_name,
-        org_module: nn.Module,
-        multiplier=1.0,
-        lora_dim=4,
-        alpha=1,
+        self,
+        lora_name,
+        org_module: nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
     ):
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         super().__init__()
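
The docstring's alpha rule is the standard kohya-style convention: alpha falls back to the rank, and the applied scale is alpha / rank, so the default leaves the LoRA output unscaled. A two-line sketch of that rule (names assumed, not quoted from this file):

alpha = lora_dim if alpha is None or alpha == 0 else alpha  # alpha defaults to rank
scale = alpha / lora_dim  # alpha == rank gives scale == 1.0, i.e. no scaling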
@@ -102,19 +104,19 @@ class LoRAModule(nn.Module):
     def forward(self, x):
         return (
-            self.org_forward(x)
-            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+            self.org_forward(x)
+            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
         )


 class LoRANetwork(nn.Module):
     def __init__(
-        self,
-        unet: UNet2DConditionModel,
-        rank: int = 4,
-        multiplier: float = 1.0,
-        alpha: float = 1.0,
-        train_method: TRAINING_METHODS = "full",
+        self,
+        unet: UNet2DConditionModel,
+        rank: int = 4,
+        multiplier: float = 1.0,
+        alpha: float = 1.0,
+        train_method: TRAINING_METHODS = "full",
     ) -> None:
         super().__init__()
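
That forward pass is the whole LoRA trick: the frozen original output plus a rank-limited residual, scaled by multiplier * scale. A self-contained toy version for a single Linear layer (TinyLoRALinear is illustrative, not part of this codebase):

import math
import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    def __init__(self, org_module: nn.Linear, rank: int = 4, alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        self.org_module = org_module  # stays frozen; only the LoRA path trains
        self.lora_down = nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, org_module.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)  # zero init: the wrapper starts as a no-op
        self.multiplier = multiplier
        self.scale = alpha / rank

    def forward(self, x):
        return (
            self.org_module(x)
            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        )

layer = TinyLoRALinear(nn.Linear(16, 16))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])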
@@ -140,7 +142,7 @@ class LoRANetwork(nn.Module):
         lora_names = set()
         for lora in self.unet_loras:
             assert (
-                lora.lora_name not in lora_names
+                lora.lora_name not in lora_names
             ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
             lora_names.add(lora.lora_name)
@@ -157,13 +159,13 @@ class LoRANetwork(nn.Module):
         torch.cuda.empty_cache()

     def create_modules(
-        self,
-        prefix: str,
-        root_module: nn.Module,
-        target_replace_modules: List[str],
-        rank: int,
-        multiplier: float,
-        train_method: TRAINING_METHODS,
+        self,
+        prefix: str,
+        root_module: nn.Module,
+        target_replace_modules: List[str],
+        rank: int,
+        multiplier: float,
+        train_method: TRAINING_METHODS,
     ) -> list:
         loras = []
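
create_modules walks root_module looking for classes named in target_replace_modules (e.g. UNET_TARGET_REPLACE_MODULE_TRANSFORMER above) and wraps their Linear and Conv2d children. The hunk only shows the re-indented signature; the loop below is a sketch of the usual kohya-style traversal, not a quote of this file's body:

for name, module in root_module.named_modules():
    if module.__class__.__name__ in target_replace_modules:
        for child_name, child_module in module.named_modules():
            if child_module.__class__.__name__ in ("Linear", "Conv2d"):
                # dots are illegal in parameter names, hence the underscores
                lora_name = f"{prefix}.{name}.{child_name}".replace(".", "_")
                loras.append(LoRAModule(lora_name, child_module, multiplier, rank))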
@@ -212,6 +214,8 @@ class LoRANetwork(nn.Module):
     def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
         state_dict = self.state_dict()

+        if metadata is None:
+            metadata = OrderedDict()

         if dtype is not None:
             for key in list(state_dict.keys()):
@@ -221,9 +225,10 @@ class LoRANetwork(nn.Module):
         for key in list(state_dict.keys()):
             if not key.startswith("lora"):
-                # lora以外除外
+                # remove any not lora
                 del state_dict[key]

+        metadata = add_model_hash_to_meta(state_dict, metadata)
         if os.path.splitext(file)[1] == ".safetensors":
             save_file(state_dict, file, metadata)
         else:
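
With the two new lines, every safetensors save now carries a weight hash in the file's metadata. Hypothetical usage (the filename and network construction are illustrative, not from this commit); safetensors can read the metadata back without loading any tensors:

network.save_weights("my_lora.safetensors", dtype=torch.float16)

from safetensors import safe_open
with safe_open("my_lora.safetensors", framework="pt") as f:
    print(f.metadata())  # includes the hash stamped by add_model_hash_to_meta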