mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00

Added Model rescale and prepared a release upgrade
README.md
@@ -40,6 +40,8 @@ pip3 install -r requirements.txt

I have so many hodge-podge scripts that I use in my ML work and will be moving over to this. But this is what is
here so far.

---

### LoRA (lierla), LoCON (LyCORIS) extractor

It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features
@@ -64,6 +66,31 @@ Most people used fixed, which is traditional fixed dimension extraction.

`process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LoCON, etc.

---

### LoRA Rescale

Change `<lora:my_lora:4.6>` to `<lora:my_lora:1.0>` or whatever you want with the same effect.
A tool for rescaling a LoRA's weights. It should work with LoCON as well, but I have not tested it.
It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yaml`.
Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`.
Then you can edit the file to your liking and call it like so:

```bash
python3 run.py config/whatever_you_want.yml
```

You can also put a full path to a config file, if you want to keep it somewhere else.

```bash
python3 run.py "/home/user/whatever_you_want.yml"
```

More notes on how it works are available in the example config file itself. This is useful when making
any LoRA, as the ideal weight is rarely 1.0, but now you can fix that. Sliders can have weird scales from -2 to 2
or even -15 to 15. This will allow you to dial it in so they all have your desired scale.
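
For reference, here is the math the rescale tool applies, as a minimal Python sketch with example values (mirrored from `jobs/process/ModRescaleLoraProcess.py` further down in this commit):

```python
import math

# Minimal sketch of the rescale math from ModRescaleLoraProcess (this commit).
# Example values; any current/target pair works the same way.
current_weight = 6.0  # multiplier you use today, e.g. <lora:my_lora:6.0>
target_weight = 1.0   # multiplier you want to use, e.g. <lora:my_lora:1.0>

# total factor the LoRA's effect must grow by
total_module_scale = current_weight / target_weight  # 6.0

# The LoRA delta is lora_up(lora_down(x)) * multiplier * scale, so scaling both
# the up and down matrices multiplies the delta by the factor squared; splitting
# the total scale evenly gives each matrix its square root.
up_down_scale = math.sqrt(total_module_scale)  # ~2.449

# Then, for every ".lora_up.weight" and ".lora_down.weight" tensor:
#     new_weight = old_weight * up_down_scale
# or, with scale_target: 'alpha', for every ".alpha" tensor:
#     new_alpha = old_alpha * total_module_scale
```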

---

### LoRA Slider Trainer

@@ -108,13 +135,32 @@ Just went in and out. It is much worse on smaller faces than shown here.

## TODO
- [X] Add proper regs on sliders
- [ ] Add SDXL support (base model only for now)
- [X] Add SDXL support (base model only for now)
- [ ] Add plain erasing
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)

---

## Change Log

#### 2023-08-01
Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
at the moment, so hopefully there are no breaking changes.

Unfortunately, I am too lazy to write a proper changelog with all the changes.

I added SDXL training to sliders... but... it does not work properly.
The slider training relies on a model's ability to understand that an unconditional (negative prompt)
means you do not want that concept in the output. SDXL does not understand this for whatever reason,
which makes separating out concepts within the model hard. I am sure the community will find a way to fix this
over time, but for now, it is not going to work properly. And if any of you are thinking "Could we maybe fix it
by adding 1 or 2 more text encoders to the model as well as a few more entirely separate diffusion networks?"
No. God no. It just needs a little training without every experimental new paper added to it. The KISS principle.

#### 2023-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights.

config/examples/mod_lora_scale.yaml (new file)
@@ -0,0 +1,48 @@
---
job: mod
config:
  name: name_of_your_model_v1
  process:
    - type: rescale_lora
      # path to your current lora model
      input_path: "/path/to/lora/lora.safetensors"
      # output path for your new lora model, can be the same as input_path to replace
      output_path: "/path/to/lora/output_lora_v1.safetensors"
      # replaces meta with the meta below (plus minimum meta fields)
      # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
      replace_meta: true
      # how to adjust, we can scale the up_down weights or the alpha
      # up_down is the default and probably the best, they will both net the same outputs
      # would only affect rare NaN cases and maybe merging with old merge tools
      scale_target: 'up_down'
      # precision to save, fp16 is the default and standard
      save_dtype: fp16
      # current_weight is the weight you currently use as a multiplier when using the lora
      # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
      # you can do negatives here too if you want to flip the lora
      current_weight: 6.0
      # target_weight is the weight you want to use as a multiplier when using the lora
      # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
      # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
      target_weight: 1.0

      # base model for the lora
      # this is just used to add meta so automatic1111 knows which model it is for
      # assumes v1.5 if these are not set
      is_xl: false
      is_v2: false

meta:
  # this is only used if you set replace_meta to true above
  name: "[name]"  # [name] gets replaced with the name above
  description: A short description of your lora
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    email: your@email.com
    website: https://yourwebsite.com
  any: All metadata above is arbitrary; it can be whatever you want.
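
For what it's worth, a hedged sketch of running this config programmatically instead of through `run.py`, based on the `get_job()` helper shown later in this commit (the module path `toolkit.job` is an assumption, not confirmed by the diff):

```python
# Hedged sketch, not part of this commit: run a mod config programmatically.
# get_job(config_path, name=None) is shown later in this diff; the module
# path `toolkit.job` is an assumption.
from toolkit.job import get_job

job = get_job('config/examples/mod_lora_scale.yaml')  # 'job: mod' -> ModJob
job.run()  # ModJob.run() executes each process, here a single rescale_lora
```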
info.py
@@ -3,6 +3,6 @@ from collections import OrderedDict
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
-v["version"] = "0.0.1"
+v["version"] = "0.0.2"

software_meta = v
jobs/ModJob.py (new file)
@@ -0,0 +1,28 @@
import os
from collections import OrderedDict
from jobs import BaseJob
from toolkit.metadata import get_meta_for_safetensors
from toolkit.train_tools import get_torch_dtype

process_dict = {
    'rescale_lora': 'ModRescaleLoraProcess',
}


class ModJob(BaseJob):

    def __init__(self, config: OrderedDict):
        super().__init__(config)
        self.device = self.get_conf('device', 'cpu')

        # loads the processes from the config
        self.load_processes(process_dict)

    def run(self):
        super().run()

        print("")
        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

        for process in self.process:
            process.run()
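
A note on `process_dict`: it maps the `type` value from the YAML (`rescale_lora`) to a class name in `jobs.process`. The resolution happens in `BaseJob.load_processes`, which this diff does not show; a hedged sketch of what that presumably looks like:

```python
import importlib

# Hedged sketch, not part of this commit: how BaseJob.load_processes presumably
# resolves process_dict entries. The real implementation lives in BaseJob.
def load_processes_sketch(job, process_dict, process_configs):
    processes = []
    for i, conf in enumerate(process_configs):
        class_name = process_dict[conf['type']]  # 'rescale_lora' -> 'ModRescaleLoraProcess'
        module = importlib.import_module('jobs.process')
        ProcessClass = getattr(module, class_name)
        processes.append(ProcessClass(i, job, conf))  # matches ModRescaleLoraProcess.__init__
    return processes
```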
@@ -2,3 +2,4 @@ from .BaseJob import BaseJob
from .ExtractJob import ExtractJob
from .TrainJob import TrainJob
from .MergeJob import MergeJob
+from .ModJob import ModJob
@@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2D
    DDIMScheduler, DDPMScheduler

from jobs.process import BaseTrainProcess
-from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
+from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
@@ -192,6 +192,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    num_inference_steps=sample_config.sample_steps,
                    guidance_scale=sample_config.guidance_scale,
                    negative_prompt=neg,
+                   guidance_rescale=0.7,
                ).images[0]
            else:
                img = pipeline(
@@ -236,21 +237,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # self.sd.tokenizer.to(original_device_dict['tokenizer'])

    def update_training_metadata(self):
-       dict = OrderedDict({
+       o_dict = OrderedDict({
            "training_info": self.get_training_info()
        })
        if self.model_config.is_v2:
-           dict['ss_v2'] = True
-           dict['ss_base_model_version'] = 'sd_2.1'
+           o_dict['ss_v2'] = True
+           o_dict['ss_base_model_version'] = 'sd_2.1'

        elif self.model_config.is_xl:
-           dict['ss_base_model_version'] = 'sdxl_1.0'
+           o_dict['ss_base_model_version'] = 'sdxl_1.0'
        else:
-           dict['ss_base_model_version'] = 'sd_1.5'
+           o_dict['ss_base_model_version'] = 'sd_1.5'

-       dict['ss_output_name'] = self.job.name
+       o_dict = add_base_model_info_to_meta(
+           o_dict,
+           is_v2=self.model_config.is_v2,
+           is_xl=self.model_config.is_xl,
+       )
+       o_dict['ss_output_name'] = self.job.name

-       self.add_meta(dict)
+       self.add_meta(o_dict)

    def get_training_info(self):
        info = OrderedDict({
@@ -381,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            text_embeddings: PromptEmbeds,
            timestep: int,
            guidance_scale=7.5,
-           guidance_rescale=0.7,
+           guidance_rescale=0,  # 0.7
            add_time_ids=None,
            **kwargs,
    ):
@@ -389,7 +395,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.sd.is_xl:
            if add_time_ids is None:
                add_time_ids = self.get_time_ids_from_latents(latents)
-           # todo LECOs code looks like it is omitting noise_pred

            latent_model_input = torch.cat([latents] * 2)
@@ -500,13 +505,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
        dtype = get_torch_dtype(self.train_config.dtype)

        # TODO handle other schedulers
-       sch = KDPM2DiscreteScheduler
+       # sch = KDPM2DiscreteScheduler
+       sch = DDPMScheduler
+       # do our own scheduler
+       prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
        scheduler = sch(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.0120,
            beta_schedule="scaled_linear",
            clip_sample=False,
+           prediction_type=prediction_type,
        )
        if self.model_config.is_xl:
            if self.custom_pipeline is not None:
jobs/process/ModRescaleLoraProcess.py (new file)
@@ -0,0 +1,100 @@
import gc
import os
from collections import OrderedDict
from typing import ForwardRef

import torch
from safetensors.torch import save_file, load_file

from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
    add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype


class ModRescaleLoraProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.input_path = self.get_conf('input_path', required=True)
        self.output_path = self.get_conf('output_path', required=True)
        self.replace_meta = self.get_conf('replace_meta', default=False)
        self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype)
        self.current_weight = self.get_conf('current_weight', required=True, as_type=float)
        self.target_weight = self.get_conf('target_weight', required=True, as_type=float)
        self.scale_target = self.get_conf('scale_target', default='up_down')  # alpha or up_down
        self.is_xl = self.get_conf('is_xl', default=False, as_type=bool)
        self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool)

        self.progress_bar = None

    def run(self):
        super().run()
        source_state_dict = load_file(self.input_path)
        source_meta = load_metadata_from_safetensors(self.input_path)

        if self.replace_meta:
            self.meta.update(
                add_base_model_info_to_meta(
                    self.meta,
                    is_xl=self.is_xl,
                    is_v2=self.is_v2,
                )
            )
            save_meta = get_meta_for_safetensors(self.meta, self.job.name)
        else:
            save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False)

        # save
        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)

        new_state_dict = OrderedDict()

        for key in list(source_state_dict.keys()):
            v = source_state_dict[key]
            v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32'))

            # all loras have an alpha, up weight and down weight
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha",
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",
            # we can rescale by adjusting the alpha or the up weights, or the up and down weights
            # I assume doing both up and down would be best all around, but I'm not sure
            # some locons also have mid weights, we will leave those alone for now, will work without them

            # when adjusting alpha, it is used to calculate the multiplier in a lora module
            # - scale = alpha / lora_dim
            # - output = layer_out + lora_up_out * multiplier * scale
            total_module_scale = torch.tensor(self.current_weight / self.target_weight) \
                .to("cpu", dtype=get_torch_dtype('fp32'))
            num_modules_layers = 2  # up and down
            up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
                .to("cpu", dtype=get_torch_dtype('fp32'))
            # only update alpha
            if self.scale_target == 'alpha' and key.endswith('.alpha'):
                v = v * total_module_scale
            if self.scale_target == 'up_down' and (key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight')):
                # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN
                v = v * up_down_scale
            new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype))

        save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
        save_file(new_state_dict, self.output_path, save_meta)

        # cleanup in case there are other jobs
        del new_state_dict
        del source_state_dict
        del source_meta

        torch.cuda.empty_cache()
        gc.collect()

        print(f"Saved to {self.output_path}")
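
The comment above claims alpha scaling and up_down scaling "net the same outputs". A quick, self-contained sanity check of that claim (not part of the commit; shapes and values are arbitrary):

```python
import torch

# Hedged sanity check (not part of this commit): verify that scaling alpha by
# current/target and scaling both up/down weights by sqrt(current/target)
# produce the same LoRA delta. Shapes are arbitrary for illustration.
torch.manual_seed(0)
lora_dim, d_in, d_out = 4, 16, 16
down = torch.randn(lora_dim, d_in)
up = torch.randn(d_out, lora_dim)
alpha = torch.tensor(1.0)
x = torch.randn(d_in)

def lora_delta(up_w, down_w, a):
    scale = a / lora_dim                   # scale = alpha / lora_dim
    return (up_w @ (down_w @ x)) * scale   # multiplier assumed to be 1.0

current_weight, target_weight = 6.0, 1.0
total = current_weight / target_weight

delta_alpha = lora_delta(up, down, alpha * total)     # scale_target: 'alpha'
s = total ** 0.5
delta_updown = lora_delta(up * s, down * s, alpha)    # scale_target: 'up_down'

assert torch.allclose(delta_alpha, delta_updown, atol=1e-5)
print("alpha and up_down rescaling match")
```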
@@ -46,8 +46,8 @@ class EncodedPromptPair:
                    negative_target,
                    negative_target_with_neutral,
                    neutral,
-                   both_targets,
                    empty_prompt,
+                   both_targets,
                    action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
                    multiplier=1.0,
                    weight=1.0
@@ -123,23 +123,24 @@ class TrainSliderProcess(BaseSDTrainProcess):
        self.print(f"Loading prompt file from {self.slider_config.prompt_file}")

        # read line by line from file
-       with open(self.slider_config.prompt_file, 'r') as f:
-           self.prompt_txt_list = f.readlines()
-           # clean empty lines
-           self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
+       if self.slider_config.prompt_file:
+           with open(self.slider_config.prompt_file, 'r') as f:
+               self.prompt_txt_list = f.readlines()
+               # clean empty lines
+               self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]

-       self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
+           self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")

-       if not self.slider_config.prompt_tensors:
-           # shuffle
-           random.shuffle(self.prompt_txt_list)
-           # trim to max steps
-           self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
-           # trim list to our max steps

        cache = PromptEmbedsCache()

+       if not self.slider_config.prompt_tensors:
+           # shuffle
+           random.shuffle(self.prompt_txt_list)
+           # trim to max steps
+           self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
+           # trim list to our max steps

        # get encoded latents for our prompts
        with torch.no_grad():
            if self.slider_config.prompt_tensors is not None:
@@ -169,7 +170,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
            # encode empty_prompt
            cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)

-           for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
+           neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
+
+           for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
                for target in self.slider_config.targets:
                    prompt_list = [
                        f"{target.target_class}",  # target_class
@@ -212,10 +215,15 @@ class TrainSliderProcess(BaseSDTrainProcess):
            save_file(state_dict, self.slider_config.prompt_tensors)

        prompt_pairs = []
-       for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
+       for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
            for target in self.slider_config.targets:
+               erase_negative = len(target.positive.strip()) == 0
+               enhance_positive = len(target.negative.strip()) == 0
+
+               both = not erase_negative and not enhance_positive
+
+               if both or erase_negative:
+                   print("Encoding erase negative")
                prompt_pairs += [
                    # erase standard
                    EncodedPromptPair(
@@ -234,6 +242,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
                    ),
                ]
+               if both or enhance_positive:
+                   print("Encoding enhance positive")
                prompt_pairs += [
                    # enhance standard, swap pos neg
                    EncodedPromptPair(
@@ -251,7 +260,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
                        weight=target.weight
                    ),
                ]
-               if both or enhance_positive:
+               # if both or enhance_positive:
+               if both:
                    print("Encoding erase positive (inverse)")
                    prompt_pairs += [
                        # erase inverted
                        EncodedPromptPair(
@@ -269,7 +280,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
                        weight=target.weight
                    ),
                ]
-               if both or erase_negative:
+               # if both or erase_negative:
+               if both:
                    print("Encoding enhance negative (inverse)")
                    prompt_pairs += [
                        # enhance inverted
                        EncodedPromptPair(
@@ -341,10 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
                torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
            ]

-           target_class = prompt_pair.target_class
-           neutral = prompt_pair.neutral
-           negative = prompt_pair.negative_target
-           positive = prompt_pair.positive_target
            weight = prompt_pair.weight
            multiplier = prompt_pair.multiplier

@@ -8,4 +8,5 @@ from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
+from .ModRescaleLoraProcess import ModRescaleLoraProcess
@@ -99,5 +99,5 @@ class SliderConfig:
        anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
        self.anchors: List[SliderConfigAnchors] = anchors
        self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
-       self.prompt_file: str = kwargs.get('prompt_file', '')
-       self.prompt_tensors: str = kwargs.get('prompt_tensors', '')
+       self.prompt_file: str = kwargs.get('prompt_file', None)
+       self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
@@ -13,6 +13,9 @@ def get_job(config_path, name=None):
    if job == 'train':
        from jobs import TrainJob
        return TrainJob(config)
+   if job == 'mod':
+       from jobs import ModJob
+       return ModJob(config)

    # elif job == 'train':
    #     from jobs import TrainJob
@@ -6,12 +6,14 @@
import os
import math
from typing import Optional, List, Type, Set, Literal
from collections import OrderedDict

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file

+from toolkit.metadata import add_model_hash_to_meta

UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
    "Transformer2DModel",  # apparently this is the right one? # attn1, 2
@@ -31,7 +33,7 @@ TRAINING_METHODS = Literal[
    "innoxattn",  # train all layers except self attention layers
    "selfattn",  # ESD-u, train only self attention layers
    "xattn",  # ESD-x, train only x attention layers
    "full",  # train all layers
    # "notime",
    # "xlayer",
    # "outxattn",
@@ -48,12 +50,12 @@ class LoRAModule(nn.Module):
    """

    def __init__(
            self,
            lora_name,
            org_module: nn.Module,
            multiplier=1.0,
            lora_dim=4,
            alpha=1,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
@@ -102,19 +104,19 @@ class LoRAModule(nn.Module):

    def forward(self, x):
        return (
            self.org_forward(x)
            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        )


class LoRANetwork(nn.Module):
    def __init__(
            self,
            unet: UNet2DConditionModel,
            rank: int = 4,
            multiplier: float = 1.0,
            alpha: float = 1.0,
            train_method: TRAINING_METHODS = "full",
    ) -> None:
        super().__init__()
@@ -140,7 +142,7 @@ class LoRANetwork(nn.Module):
        lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in lora_names
            ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
            lora_names.add(lora.lora_name)
@@ -157,13 +159,13 @@ class LoRANetwork(nn.Module):
        torch.cuda.empty_cache()

    def create_modules(
            self,
            prefix: str,
            root_module: nn.Module,
            target_replace_modules: List[str],
            rank: int,
            multiplier: float,
            train_method: TRAINING_METHODS,
    ) -> list:
        loras = []
@@ -212,6 +214,8 @@ class LoRANetwork(nn.Module):
    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        state_dict = self.state_dict()
+       if metadata is None:
+           metadata = OrderedDict()

        if dtype is not None:
            for key in list(state_dict.keys()):
@@ -221,9 +225,10 @@ class LoRANetwork(nn.Module):
        for key in list(state_dict.keys()):
            if not key.startswith("lora"):
-               # lora以外除外 (exclude non-lora keys)
+               # remove any not lora
                del state_dict[key]

+       metadata = add_model_hash_to_meta(state_dict, metadata)
        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:
@@ -1,18 +1,23 @@
import json
from collections import OrderedDict
+from io import BytesIO

+import safetensors
from safetensors import safe_open

from info import software_meta
+from toolkit.train_tools import addnet_hash_legacy
+from toolkit.train_tools import addnet_hash_safetensors


-def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
+def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict:
    # stringify the meta and reparse OrderedDict to replace [name] with name
    meta_string = json.dumps(meta)
    if name is not None:
        meta_string = meta_string.replace("[name]", name)
    save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
-   save_meta["software"] = software_meta
+   if add_software_info:
+       save_meta["software"] = software_meta
    # safetensors can only be one level deep
    for key, value in save_meta.items():
        # if not float, int, bool, or str, convert to json string
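
The stringify-and-reparse trick above is what makes the `name: "[name]"` placeholder in the example config work; a quick illustration (not part of the commit):

```python
import json
from collections import OrderedDict

# Illustration (not part of this commit) of the [name] replacement above.
meta = OrderedDict({"name": "[name]", "description": "A lora named [name]"})
meta_string = json.dumps(meta).replace("[name]", "name_of_your_model_v1")
save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
print(save_meta["name"])  # -> name_of_your_model_v1
```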
@@ -21,6 +26,46 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
    return save_meta


+def add_model_hash_to_meta(state_dict, meta: OrderedDict) -> OrderedDict:
+    """Precalculate the model hashes needed by sd-webui-additional-networks to
+    save time on indexing the model later."""
+
+    # Because writing user metadata to the file can change the result of
+    # sd_models.model_hash(), only retain the training metadata for purposes of
+    # calculating the hash, as they are meant to be immutable
+    metadata = {k: v for k, v in meta.items() if k.startswith("ss_")}
+
+    bytes = safetensors.torch.save(state_dict, metadata)
+    b = BytesIO(bytes)
+
+    model_hash = addnet_hash_safetensors(b)
+    legacy_hash = addnet_hash_legacy(b)
+    meta["sshs_model_hash"] = model_hash
+    meta["sshs_legacy_hash"] = legacy_hash
+    return meta
+
+
+def add_base_model_info_to_meta(
+        meta: OrderedDict,
+        base_model: str = None,
+        is_v1: bool = False,
+        is_v2: bool = False,
+        is_xl: bool = False,
+) -> OrderedDict:
+    if base_model is not None:
+        meta['ss_base_model'] = base_model
+    elif is_v2:
+        meta['ss_v2'] = True
+        meta['ss_base_model_version'] = 'sd_2.1'
+    elif is_xl:
+        meta['ss_base_model_version'] = 'sdxl_1.0'
+    else:
+        # default to v1.5
+        meta['ss_base_model_version'] = 'sd_1.5'
+    return meta


def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
    parsed_meta = OrderedDict()
    for key, value in meta.items():
@@ -54,6 +54,8 @@ def get_optimizer(
    elif lower_type == 'lion':
        from lion_pytorch import Lion
        return Lion(params, lr=learning_rate, **optimizer_params)
+   elif lower_type == 'adagrad':
+       optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
    else:
        raise ValueError(f'Unknown optimizer type {optimizer_type}')
    return optimizer
@@ -1,4 +1,5 @@
import argparse
+import hashlib
import json
import os
import time
@@ -399,3 +400,29 @@ def concat_prompt_embeddings(
        [unconditional.pooled_embeds, conditional.pooled_embeds]
    ).repeat_interleave(n_imgs, dim=0)
    return PromptEmbeds([text_embeds, pooled_embeds])


+def addnet_hash_safetensors(b):
+    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+    hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
+
+    b.seek(0)
+    header = b.read(8)
+    n = int.from_bytes(header, "little")
+
+    offset = n + 8
+    b.seek(offset)
+    for chunk in iter(lambda: b.read(blksize), b""):
+        hash_sha256.update(chunk)
+
+    return hash_sha256.hexdigest()
+
+
+def addnet_hash_legacy(b):
+    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+    m = hashlib.sha256()
+
+    b.seek(0x100000)
+    m.update(b.read(0x10000))
+    return m.hexdigest()[0:8]
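
For context, a hedged usage sketch (not part of the commit) showing how these helpers pair with `add_model_hash_to_meta` above, here hashing a .safetensors file read from disk (the path is hypothetical):

```python
from io import BytesIO

# Hedged usage sketch (not part of this commit): compute both addnet hashes for
# an existing .safetensors file, the same way add_model_hash_to_meta does for an
# in-memory state dict. Relies on the two functions defined above.
def hash_safetensors_file(path: str) -> tuple[str, str]:
    with open(path, "rb") as f:
        b = BytesIO(f.read())
    # addnet_hash_safetensors skips the JSON header (8-byte little-endian
    # length + header bytes) and hashes only the tensor payload.
    model_hash = addnet_hash_safetensors(b)
    # addnet_hash_legacy hashes 0x10000 bytes starting at offset 0x100000.
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash

# model_hash, legacy_hash = hash_safetensors_file("/path/to/lora.safetensors")
```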