Added Model rescale and prepared a release upgrade

This commit is contained in:
Jaret Burkett
2023-08-01 13:49:54 -06:00
parent 63cacf4362
commit 8b8d53888d
15 changed files with 388 additions and 64 deletions

View File

@@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2D
DDIMScheduler, DDPMScheduler
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
@@ -192,6 +192,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
guidance_rescale=0.7,
).images[0]
else:
img = pipeline(
@@ -236,21 +237,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
def update_training_metadata(self):
dict = OrderedDict({
o_dict = OrderedDict({
"training_info": self.get_training_info()
})
if self.model_config.is_v2:
dict['ss_v2'] = True
dict['ss_base_model_version'] = 'sd_2.1'
o_dict['ss_v2'] = True
o_dict['ss_base_model_version'] = 'sd_2.1'
elif self.model_config.is_xl:
dict['ss_base_model_version'] = 'sdxl_1.0'
o_dict['ss_base_model_version'] = 'sdxl_1.0'
else:
dict['ss_base_model_version'] = 'sd_1.5'
o_dict['ss_base_model_version'] = 'sd_1.5'
dict['ss_output_name'] = self.job.name
o_dict = add_base_model_info_to_meta(
o_dict,
is_v2=self.model_config.is_v2,
is_xl=self.model_config.is_xl,
)
o_dict['ss_output_name'] = self.job.name
self.add_meta(dict)
self.add_meta(o_dict)
def get_training_info(self):
info = OrderedDict({
@@ -381,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_embeddings: PromptEmbeds,
timestep: int,
guidance_scale=7.5,
guidance_rescale=0.7,
guidance_rescale=0, # 0.7
add_time_ids=None,
**kwargs,
):
@@ -389,7 +395,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.sd.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
# todo LECOs code looks like it is omitting noise_pred
latent_model_input = torch.cat([latents] * 2)
@@ -500,13 +505,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
# TODO handle other schedulers
sch = KDPM2DiscreteScheduler
# sch = KDPM2DiscreteScheduler
sch = DDPMScheduler
# do our own scheduler
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
scheduler = sch(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
clip_sample=False,
prediction_type=prediction_type,
)
if self.model_config.is_xl:
if self.custom_pipeline is not None:

View File

@@ -0,0 +1,100 @@
import gc
import os
from collections import OrderedDict
from typing import ForwardRef
import torch
from safetensors.torch import save_file, load_file
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype
class ModRescaleLoraProcess(BaseProcess):
    """Rescale a saved LoRA so that applying it at ``target_weight`` reproduces
    the effect it previously had at ``current_weight``.

    Loads a safetensors LoRA from ``input_path``, multiplies either the alpha
    tensors or the up/down weight tensors by ``current_weight / target_weight``
    (split evenly across up and down when scaling both), and writes the result
    to ``output_path`` with updated metadata and a model hash.
    """
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # Required I/O paths for the source and rescaled LoRA files.
        self.input_path = self.get_conf('input_path', required=True)
        self.output_path = self.get_conf('output_path', required=True)
        # When True, discard the source file's metadata and rebuild it from
        # this job's meta plus base-model info.
        self.replace_meta = self.get_conf('replace_meta', default=False)
        self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype)
        # Weight the LoRA is currently used at, and the weight it should be
        # used at after rescaling.
        self.current_weight = self.get_conf('current_weight', required=True, as_type=float)
        self.target_weight = self.get_conf('target_weight', required=True, as_type=float)
        # What to scale: 'alpha' or 'up_down' (up and down weights).
        self.scale_target = self.get_conf('scale_target', default='up_down')
        self.is_xl = self.get_conf('is_xl', default=False, as_type=bool)
        self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool)

        self.progress_bar = None

    def run(self):
        """Load, rescale, and save the LoRA state dict."""
        super().run()
        source_state_dict = load_file(self.input_path)
        source_meta = load_metadata_from_safetensors(self.input_path)

        if self.replace_meta:
            self.meta.update(
                add_base_model_info_to_meta(
                    self.meta,
                    is_xl=self.is_xl,
                    is_v2=self.is_v2,
                )
            )
            save_meta = get_meta_for_safetensors(self.meta, self.job.name)
        else:
            save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False)

        # make sure the output directory exists before saving
        os.makedirs(os.path.dirname(self.output_path), exist_ok=True)

        # all loras have an alpha, up weight and down weight
        # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha",
        # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
        # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",
        # we can rescale by adjusting the alpha or the up weights, or the up and down weights
        # I assume doing both up and down would be best all around, but I'm not sure
        # some locons also have mid weights, we will leave those alone for now, will work without them
        # when adjusting alpha, it is used to calculate the multiplier in a lora module
        # - scale = alpha / lora_dim
        # - output = layer_out + lora_up_out * multiplier * scale
        #
        # These scale factors are the same for every key, so compute them once.
        total_module_scale = torch.tensor(self.current_weight / self.target_weight) \
            .to("cpu", dtype=get_torch_dtype('fp32'))
        num_modules_layers = 2  # up and down
        # split the total scale evenly across the up and down matmuls
        up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
            .to("cpu", dtype=get_torch_dtype('fp32'))

        new_state_dict = OrderedDict()
        for key in list(source_state_dict.keys()):
            v = source_state_dict[key]
            # work in fp32 on cpu to minimize precision loss while scaling
            v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32'))

            # only update alpha
            if self.scale_target == 'alpha' and key.endswith('.alpha'):
                v = v * total_module_scale
            # parentheses matter here: without them, 'and' binds tighter than
            # 'or' and every .lora_down.weight would be scaled regardless of
            # scale_target
            if self.scale_target == 'up_down' and (key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight')):
                # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN
                v = v * up_down_scale

            new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype))

        # hash and save once, over the fully rescaled state dict
        save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
        save_file(new_state_dict, self.output_path, save_meta)

        # cleanup in case there are other jobs
        del new_state_dict
        del source_state_dict
        del source_meta
        torch.cuda.empty_cache()
        gc.collect()

        print(f"Saved to {self.output_path}")

View File

@@ -46,8 +46,8 @@ class EncodedPromptPair:
negative_target,
negative_target_with_neutral,
neutral,
both_targets,
empty_prompt,
both_targets,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=1.0,
weight=1.0
@@ -123,23 +123,24 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
# read line by line from file
with open(self.slider_config.prompt_file, 'r') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
if self.slider_config.prompt_file:
with open(self.slider_config.prompt_file, 'r') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
if not self.slider_config.prompt_tensors:
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
# trim list to our max steps
cache = PromptEmbedsCache()
if not self.slider_config.prompt_tensors:
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
# trim list to our max steps
# get encoded latents for our prompts
with torch.no_grad():
if self.slider_config.prompt_tensors is not None:
@@ -169,7 +170,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
# encode empty_prompt
cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
for target in self.slider_config.targets:
prompt_list = [
f"{target.target_class}", # target_class
@@ -212,10 +215,15 @@ class TrainSliderProcess(BaseSDTrainProcess):
save_file(state_dict, self.slider_config.prompt_tensors)
prompt_pairs = []
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
for target in self.slider_config.targets:
erase_negative = len(target.positive.strip()) == 0
enhance_positive = len(target.negative.strip()) == 0
both = not erase_negative and not enhance_positive
if both or erase_negative:
print("Encoding erase negative")
prompt_pairs += [
# erase standard
EncodedPromptPair(
@@ -234,6 +242,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
]
if both or enhance_positive:
print("Encoding enhance positive")
prompt_pairs += [
# enhance standard, swap pos neg
EncodedPromptPair(
@@ -251,7 +260,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
weight=target.weight
),
]
if both or enhance_positive:
# if both or enhance_positive:
if both:
print("Encoding erase positive (inverse)")
prompt_pairs += [
# erase inverted
EncodedPromptPair(
@@ -269,7 +280,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
weight=target.weight
),
]
if both or erase_negative:
# if both or erase_negative:
if both:
print("Encoding enhance negative (inverse)")
prompt_pairs += [
# enhance inverted
EncodedPromptPair(
@@ -341,10 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
target_class = prompt_pair.target_class
neutral = prompt_pair.neutral
negative = prompt_pair.negative_target
positive = prompt_pair.positive_target
weight = prompt_pair.weight
multiplier = prompt_pair.multiplier

View File

@@ -8,4 +8,5 @@ from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess