Added Model rescale and prepared a release upgrade

Jaret Burkett
2023-08-01 13:49:54 -06:00
parent 63cacf4362
commit 8b8d53888d
15 changed files with 388 additions and 64 deletions

View File

@@ -40,6 +40,8 @@ pip3 install -r requirements.txt
I have a lot of hodgepodge scripts from my ML work that I will be moving over to this repo, but this is what is
here so far.
---
### LoRA (lierla), LoCON (LyCORIS) extractor
It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features
@@ -64,6 +66,31 @@ Most people used fixed, which is traditional fixed dimension extraction.
`process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LyCORIS, etc. A sketch of what that might look like is below.
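For example, a config's `process` section might look something like this (the type names and keys here are illustrative assumptions; check the example configs in `config/examples` for the real ones):

```yaml
# illustrative only; type names and fields are assumptions, see config/examples
job: extract
config:
  name: my_extractions
  process:
    - type: locon   # hypothetical type name
      mode: fixed   # traditional fixed dimension extraction
      linear: 64
    - type: lora    # hypothetical type name
      mode: fixed
      linear: 32
```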
---
### LoRA Rescale
Change `<lora:my_lora:4.6>` to `<lora:my_lora:1.0>` or whatever you want with the same effect.
A tool for rescaling a LoRA's weights. It should work with LoCON as well, but I have not tested it.
It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yml`.
Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`.
Then you can edit the file to your liking and call it like so:
```bash
python3 run.py config/whatever_you_want.yml
```
You can also put a full path to a config file if you want to keep it somewhere else.
```bash
python3 run.py "/home/user/whatever_you_want.yml"
```
More notes on how it works are available in the example config file itself. This is useful when making
any LoRA, as the ideal weight is rarely 1.0, but now you can fix that. Sliders can have weird scales from -2 to 2
or even -15 to 15. This will allow you to dial them in so they all have your desired scale.
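If you are curious what the rescale actually does under the hood, here is a minimal sketch of the math, mirroring the `ModRescaleLoraProcess` code added in this commit: the total scale is `current_weight / target_weight`, and in `up_down` mode it is split evenly across the up and down matrices so their product nets the full scale.

```python
# minimal sketch of the up_down rescale math (see ModRescaleLoraProcess below)
import torch

current_weight, target_weight = 6.0, 1.0
total_module_scale = torch.tensor(current_weight / target_weight)
# sqrt of the total scale, applied to both lora_up and lora_down,
# multiplies out to the full scale in the LoRA forward pass
up_down_scale = torch.pow(total_module_scale, 1.0 / 2)

# for key, v in state_dict.items():  # a kohya-style LoRA state dict
#     if key.endswith(".lora_up.weight") or key.endswith(".lora_down.weight"):
#         state_dict[key] = v * up_down_scale
```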
---
### LoRA Slider Trainer
@@ -108,13 +135,32 @@ Just went in and out. It is much worse on smaller faces than shown here.
## TODO
- [X] Add proper regs on sliders
- [ ] Add SDXL support (base model only for now)
- [X] Add SDXL support (base model only for now)
- [ ] Add plain erasing
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
---
## Change Log
#### 2023-08-01
Major changes and updates. New LoRA rescale tool, look above for details. Added better metadata so
Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
at the moment, so hopefully there are no breaking changes.
Unfortunately, I am too lazy to write a proper changelog with all the changes.
I added SDXL training to sliders... but it does not work properly.
Slider training relies on a model's ability to understand that an unconditional (negative prompt)
means you do not want that concept in the output. SDXL does not understand this for whatever reason,
which makes separating out concepts within the model hard. I am sure the community will find a way to fix this
over time, but for now, it is not going to work properly. And if any of you are thinking, "Could we maybe fix it
by adding 1 or 2 more text encoders to the model, as well as a few more entirely separate diffusion networks?"
No. God no. It just needs a little training without every experimental new paper added to it. The KISS principle.
#### 2023-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights.
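As a rough sketch, an anchor entry in a slider config could look like this (the field names are my assumption based on `SliderConfigAnchors`; check the slider example configs for the real ones):

```yaml
# assumed field names, illustrative only
anchors:
  - prompt: "a photo of a person"
    neg_prompt: "blurry, low quality"
    multiplier: 1.0
```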

View File

@@ -0,0 +1,48 @@
---
job: mod
config:
  name: name_of_your_model_v1
  process:
    - type: rescale_lora
      # path to your current lora model
      input_path: "/path/to/lora/lora.safetensors"
      # output path for your new lora model, can be the same as input_path to replace
      output_path: "/path/to/lora/output_lora_v1.safetensors"
      # replaces meta with the meta below (plus minimum meta fields)
      # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
      replace_meta: true
      # how to adjust, we can scale the up_down weights or the alpha
      # up_down is the default and probably the best, they will both net the same outputs
      # would only affect rare NaN cases and maybe merging with old merge tools
      scale_target: 'up_down'
      # precision to save, fp16 is the default and standard
      save_dtype: fp16
      # current_weight is the ideal weight you use as a multiplier when using the lora
      # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
      # you can do negatives here too if you want to flip the lora
      current_weight: 6.0
      # target_weight is the ideal weight you use as a multiplier when using the lora
      # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
      # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
      target_weight: 1.0
      # base model for the lora
      # this is just used to add meta so automatic1111 knows which model it is for
      # assume v1.5 if these are not set
      is_xl: false
      is_v2: false

meta:
  # this is only used if you set replace_meta to true above
  name: "[name]"  # [name] gets replaced with the name above
  description: A short description of your lora
  trigger_words:
    - put
    - trigger
    - words
    - here
  version: '0.1'
  creator:
    name: Your Name
    email: your@email.com
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
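Assuming you saved a copy of the example above as `config/my_lora_rescale.yml` (the filename is arbitrary), you would run it like any other job config:

```bash
python3 run.py config/my_lora_rescale.yml
```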

View File

@@ -3,6 +3,6 @@ from collections import OrderedDict
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.0.1"
v["version"] = "0.0.2"
software_meta = v

jobs/ModJob.py Normal file
View File

@@ -0,0 +1,28 @@
import os
from collections import OrderedDict
from jobs import BaseJob
from toolkit.metadata import get_meta_for_safetensors
from toolkit.train_tools import get_torch_dtype

process_dict = {
    'rescale_lora': 'ModRescaleLoraProcess',
}


class ModJob(BaseJob):

    def __init__(self, config: OrderedDict):
        super().__init__(config)
        self.device = self.get_conf('device', 'cpu')

        # loads the processes from the config
        self.load_processes(process_dict)

    def run(self):
        super().run()
        print("")
        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

        for process in self.process:
            process.run()

View File

@@ -2,3 +2,4 @@ from .BaseJob import BaseJob
from .ExtractJob import ExtractJob
from .TrainJob import TrainJob
from .MergeJob import MergeJob
from .ModJob import ModJob

View File

@@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2D
    DDIMScheduler, DDPMScheduler
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
@@ -192,6 +192,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    num_inference_steps=sample_config.sample_steps,
                    guidance_scale=sample_config.guidance_scale,
                    negative_prompt=neg,
                    guidance_rescale=0.7,
                ).images[0]
            else:
                img = pipeline(
@@ -236,21 +237,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # self.sd.tokenizer.to(original_device_dict['tokenizer'])

    def update_training_metadata(self):
        dict = OrderedDict({
        o_dict = OrderedDict({
            "training_info": self.get_training_info()
        })
        if self.model_config.is_v2:
            dict['ss_v2'] = True
            dict['ss_base_model_version'] = 'sd_2.1'
            o_dict['ss_v2'] = True
            o_dict['ss_base_model_version'] = 'sd_2.1'
        elif self.model_config.is_xl:
            dict['ss_base_model_version'] = 'sdxl_1.0'
            o_dict['ss_base_model_version'] = 'sdxl_1.0'
        else:
            dict['ss_base_model_version'] = 'sd_1.5'
            o_dict['ss_base_model_version'] = 'sd_1.5'

        dict['ss_output_name'] = self.job.name
        o_dict = add_base_model_info_to_meta(
            o_dict,
            is_v2=self.model_config.is_v2,
            is_xl=self.model_config.is_xl,
        )

        o_dict['ss_output_name'] = self.job.name

        self.add_meta(dict)
        self.add_meta(o_dict)

    def get_training_info(self):
        info = OrderedDict({
@@ -381,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            text_embeddings: PromptEmbeds,
            timestep: int,
            guidance_scale=7.5,
            guidance_rescale=0.7,
            guidance_rescale=0,  # 0.7
            add_time_ids=None,
            **kwargs,
    ):
@@ -389,7 +395,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.sd.is_xl:
            if add_time_ids is None:
                add_time_ids = self.get_time_ids_from_latents(latents)
        # todo LECOs code looks like it is omitting noise_pred

        latent_model_input = torch.cat([latents] * 2)
@@ -500,13 +505,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
        dtype = get_torch_dtype(self.train_config.dtype)

        # TODO handle other schedulers
        sch = KDPM2DiscreteScheduler
        # sch = KDPM2DiscreteScheduler
        sch = DDPMScheduler
        # do our own scheduler
        prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
        scheduler = sch(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.0120,
            beta_schedule="scaled_linear",
            clip_sample=False,
            prediction_type=prediction_type,
        )

        if self.model_config.is_xl:
            if self.custom_pipeline is not None:

View File

@@ -0,0 +1,100 @@
import gc
import os
from collections import OrderedDict
from typing import ForwardRef

import torch
from safetensors.torch import save_file, load_file

from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
    add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype


class ModRescaleLoraProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.input_path = self.get_conf('input_path', required=True)
        self.output_path = self.get_conf('output_path', required=True)
        self.replace_meta = self.get_conf('replace_meta', default=False)
        self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype)
        self.current_weight = self.get_conf('current_weight', required=True, as_type=float)
        self.target_weight = self.get_conf('target_weight', required=True, as_type=float)
        self.scale_target = self.get_conf('scale_target', default='up_down')  # alpha or up_down
        self.is_xl = self.get_conf('is_xl', default=False, as_type=bool)
        self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool)

        self.progress_bar = None

    def run(self):
        super().run()
        source_state_dict = load_file(self.input_path)
        source_meta = load_metadata_from_safetensors(self.input_path)

        if self.replace_meta:
            self.meta.update(
                add_base_model_info_to_meta(
                    self.meta,
                    is_xl=self.is_xl,
                    is_v2=self.is_v2,
                )
            )
            save_meta = get_meta_for_safetensors(self.meta, self.job.name)
        else:
            save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False)

        # save
        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)

        new_state_dict = OrderedDict()

        for key in list(source_state_dict.keys()):
            v = source_state_dict[key]
            v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32'))

            # all loras have an alpha, up weight and down weight
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha",
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
            # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",
            # we can rescale by adjusting the alpha or the up weights, or the up and down weights
            # I assume doing both up and down would be best all around, but I'm not sure
            # some locons also have mid weights, we will leave those alone for now, will work without them
            # when adjusting alpha, it is used to calculate the multiplier in a lora module
            # - scale = alpha / lora_dim
            # - output = layer_out + lora_up_out * multiplier * scale
            total_module_scale = torch.tensor(self.current_weight / self.target_weight) \
                .to("cpu", dtype=get_torch_dtype('fp32'))
            num_modules_layers = 2  # up and down
            up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
                .to("cpu", dtype=get_torch_dtype('fp32'))

            # only update alpha
            if self.scale_target == 'alpha' and key.endswith('.alpha'):
                v = v * total_module_scale
            # parentheses matter here; without them, `and` binds before `or` and
            # down weights would get scaled even in alpha mode
            if self.scale_target == 'up_down' and (key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight')):
                # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN
                v = v * up_down_scale
            new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype))

        save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
        save_file(new_state_dict, self.output_path, save_meta)

        # cleanup in case there are other jobs
        del new_state_dict
        del source_state_dict
        del source_meta

        torch.cuda.empty_cache()
        gc.collect()

        print(f"Saved to {self.output_path}")

View File

@@ -46,8 +46,8 @@ class EncodedPromptPair:
        negative_target,
        negative_target_with_neutral,
        neutral,
        both_targets,
        empty_prompt,
        both_targets,
        action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
        multiplier=1.0,
        weight=1.0
@@ -123,23 +123,24 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
# read line by line from file
with open(self.slider_config.prompt_file, 'r') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
if self.slider_config.prompt_file:
with open(self.slider_config.prompt_file, 'r') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
if not self.slider_config.prompt_tensors:
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
# trim list to our max steps
cache = PromptEmbedsCache()
if not self.slider_config.prompt_tensors:
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
# trim list to our max steps
# get encoded latents for our prompts
with torch.no_grad():
if self.slider_config.prompt_tensors is not None:
@@ -169,7 +170,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
            # encode empty_prompt
            cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)

            for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
            neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
            for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
                for target in self.slider_config.targets:
                    prompt_list = [
                        f"{target.target_class}",  # target_class
@@ -212,10 +215,15 @@ class TrainSliderProcess(BaseSDTrainProcess):
            save_file(state_dict, self.slider_config.prompt_tensors)

        prompt_pairs = []
        for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
        for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
            for target in self.slider_config.targets:
                erase_negative = len(target.positive.strip()) == 0
                enhance_positive = len(target.negative.strip()) == 0
                both = not erase_negative and not enhance_positive

                if both or erase_negative:
                    print("Encoding erase negative")
                    prompt_pairs += [
                        # erase standard
                        EncodedPromptPair(
@@ -234,6 +242,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
                        ),
                    ]
                if both or enhance_positive:
                    print("Encoding enhance positive")
                    prompt_pairs += [
                        # enhance standard, swap pos neg
                        EncodedPromptPair(
@@ -251,7 +260,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
                            weight=target.weight
                        ),
                    ]
                if both or enhance_positive:
                # if both or enhance_positive:
                if both:
                    print("Encoding erase positive (inverse)")
                    prompt_pairs += [
                        # erase inverted
                        EncodedPromptPair(
@@ -269,7 +280,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
                            weight=target.weight
                        ),
                    ]
                if both or erase_negative:
                # if both or erase_negative:
                if both:
                    print("Encoding enhance negative (inverse)")
                    prompt_pairs += [
                        # enhance inverted
                        EncodedPromptPair(
@@ -341,10 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
                torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
            ]
        target_class = prompt_pair.target_class
        neutral = prompt_pair.neutral
        negative = prompt_pair.negative_target
        positive = prompt_pair.positive_target
        weight = prompt_pair.weight
        multiplier = prompt_pair.multiplier

View File

@@ -8,4 +8,5 @@ from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess

View File

@@ -99,5 +99,5 @@ class SliderConfig:
            anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
        self.anchors: List[SliderConfigAnchors] = anchors
        self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
        self.prompt_file: str = kwargs.get('prompt_file', '')
        self.prompt_tensors: str = kwargs.get('prompt_tensors', '')
        self.prompt_file: str = kwargs.get('prompt_file', None)
        self.prompt_tensors: str = kwargs.get('prompt_tensors', None)

View File

@@ -13,6 +13,9 @@ def get_job(config_path, name=None):
    if job == 'train':
        from jobs import TrainJob
        return TrainJob(config)
    if job == 'mod':
        from jobs import ModJob
        return ModJob(config)
    # elif job == 'train':
    #     from jobs import TrainJob

View File

@@ -6,12 +6,14 @@
import os
import math
from typing import Optional, List, Type, Set, Literal
from collections import OrderedDict
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
from toolkit.metadata import add_model_hash_to_meta
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
"Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
@@ -31,7 +33,7 @@ TRAINING_METHODS = Literal[
"innoxattn", # train all layers except self attention layers
"selfattn", # ESD-u, train only self attention layers
"xattn", # ESD-x, train only x attention layers
"full", # train all layers
"full", # train all layers
# "notime",
# "xlayer",
# "outxattn",
@@ -48,12 +50,12 @@ class LoRAModule(nn.Module):
"""
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
@@ -102,19 +104,19 @@ class LoRAModule(nn.Module):
    def forward(self, x):
        return (
            self.org_forward(x)
            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        )
class LoRANetwork(nn.Module):
    def __init__(
            self,
            unet: UNet2DConditionModel,
            rank: int = 4,
            multiplier: float = 1.0,
            alpha: float = 1.0,
            train_method: TRAINING_METHODS = "full",
    ) -> None:
        super().__init__()
@@ -140,7 +142,7 @@ class LoRANetwork(nn.Module):
        lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in lora_names
            ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
            lora_names.add(lora.lora_name)
@@ -157,13 +159,13 @@ class LoRANetwork(nn.Module):
        torch.cuda.empty_cache()

    def create_modules(
            self,
            prefix: str,
            root_module: nn.Module,
            target_replace_modules: List[str],
            rank: int,
            multiplier: float,
            train_method: TRAINING_METHODS,
    ) -> list:
        loras = []
@@ -212,6 +214,8 @@ class LoRANetwork(nn.Module):
    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        state_dict = self.state_dict()

        if metadata is None:
            metadata = OrderedDict()

        if dtype is not None:
            for key in list(state_dict.keys()):
@@ -221,9 +225,10 @@ class LoRANetwork(nn.Module):
        for key in list(state_dict.keys()):
            if not key.startswith("lora"):
                # remove any not lora
                del state_dict[key]

        metadata = add_model_hash_to_meta(state_dict, metadata)

        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:

View File

@@ -1,18 +1,23 @@
import json
from collections import OrderedDict
from io import BytesIO
import safetensors
from safetensors import safe_open
from info import software_meta
from toolkit.train_tools import addnet_hash_legacy
from toolkit.train_tools import addnet_hash_safetensors
def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict:
    # stringify the meta and reparse OrderedDict to replace [name] with name
    meta_string = json.dumps(meta)
    if name is not None:
        meta_string = meta_string.replace("[name]", name)
    save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
    save_meta["software"] = software_meta
    if add_software_info:
        save_meta["software"] = software_meta

    # safetensors can only be one level deep
    for key, value in save_meta.items():
        # if not float, int, bool, or str, convert to json string
@@ -21,6 +26,46 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
    return save_meta
def add_model_hash_to_meta(state_dict, meta: OrderedDict) -> OrderedDict:
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later."""

    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    metadata = {k: v for k, v in meta.items() if k.startswith("ss_")}

    bytes = safetensors.torch.save(state_dict, metadata)
    b = BytesIO(bytes)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    meta["sshs_model_hash"] = model_hash
    meta["sshs_legacy_hash"] = legacy_hash
    return meta


def add_base_model_info_to_meta(
        meta: OrderedDict,
        base_model: str = None,
        is_v1: bool = False,
        is_v2: bool = False,
        is_xl: bool = False,
) -> OrderedDict:
    if base_model is not None:
        meta['ss_base_model'] = base_model
    elif is_v2:
        meta['ss_v2'] = True
        meta['ss_base_model_version'] = 'sd_2.1'
    elif is_xl:
        meta['ss_base_model_version'] = 'sdxl_1.0'
    else:
        # default to v1.5
        meta['ss_base_model_version'] = 'sd_1.5'
    return meta


def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
    parsed_meta = OrderedDict()
    for key, value in meta.items():

View File

@@ -54,6 +54,8 @@ def get_optimizer(
    elif lower_type == 'lion':
        from lion_pytorch import Lion
        return Lion(params, lr=learning_rate, **optimizer_params)
    elif lower_type == 'adagrad':
        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
    else:
        raise ValueError(f'Unknown optimizer type {optimizer_type}')
    return optimizer

View File

@@ -1,4 +1,5 @@
import argparse
import hashlib
import json
import os
import time
@@ -399,3 +400,29 @@ def concat_prompt_embeddings(
        [unconditional.pooled_embeds, conditional.pooled_embeds]
    ).repeat_interleave(n_imgs, dim=0)

    return PromptEmbeds([text_embeds, pooled_embeds])
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")
    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()


def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    m = hashlib.sha256()
    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]