mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added Model rescale and prepared a release upgrade
This commit is contained in:
@@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2D
|
||||
DDIMScheduler, DDPMScheduler
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import gc
|
||||
|
||||
@@ -192,6 +192,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
num_inference_steps=sample_config.sample_steps,
|
||||
guidance_scale=sample_config.guidance_scale,
|
||||
negative_prompt=neg,
|
||||
guidance_rescale=0.7,
|
||||
).images[0]
|
||||
else:
|
||||
img = pipeline(
|
||||
@@ -236,21 +237,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
|
||||
|
||||
def update_training_metadata(self):
|
||||
dict = OrderedDict({
|
||||
o_dict = OrderedDict({
|
||||
"training_info": self.get_training_info()
|
||||
})
|
||||
if self.model_config.is_v2:
|
||||
dict['ss_v2'] = True
|
||||
dict['ss_base_model_version'] = 'sd_2.1'
|
||||
o_dict['ss_v2'] = True
|
||||
o_dict['ss_base_model_version'] = 'sd_2.1'
|
||||
|
||||
elif self.model_config.is_xl:
|
||||
dict['ss_base_model_version'] = 'sdxl_1.0'
|
||||
o_dict['ss_base_model_version'] = 'sdxl_1.0'
|
||||
else:
|
||||
dict['ss_base_model_version'] = 'sd_1.5'
|
||||
o_dict['ss_base_model_version'] = 'sd_1.5'
|
||||
|
||||
dict['ss_output_name'] = self.job.name
|
||||
o_dict = add_base_model_info_to_meta(
|
||||
o_dict,
|
||||
is_v2=self.model_config.is_v2,
|
||||
is_xl=self.model_config.is_xl,
|
||||
)
|
||||
o_dict['ss_output_name'] = self.job.name
|
||||
|
||||
self.add_meta(dict)
|
||||
self.add_meta(o_dict)
|
||||
|
||||
def get_training_info(self):
|
||||
info = OrderedDict({
|
||||
@@ -381,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_embeddings: PromptEmbeds,
|
||||
timestep: int,
|
||||
guidance_scale=7.5,
|
||||
guidance_rescale=0.7,
|
||||
guidance_rescale=0, # 0.7
|
||||
add_time_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -389,7 +395,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.sd.is_xl:
|
||||
if add_time_ids is None:
|
||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||
# todo LECOs code looks like it is omitting noise_pred
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
@@ -500,13 +505,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# TODO handle other schedulers
|
||||
sch = KDPM2DiscreteScheduler
|
||||
# sch = KDPM2DiscreteScheduler
|
||||
sch = DDPMScheduler
|
||||
# do our own scheduler
|
||||
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
|
||||
scheduler = sch(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.0120,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
if self.model_config.is_xl:
|
||||
if self.custom_pipeline is not None:
|
||||
|
||||
100
jobs/process/ModRescaleLoraProcess.py
Normal file
100
jobs/process/ModRescaleLoraProcess.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import gc
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import ForwardRef
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file, load_file
|
||||
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
|
||||
add_base_model_info_to_meta
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
|
||||
class ModRescaleLoraProcess(BaseProcess):
|
||||
process_id: int
|
||||
config: OrderedDict
|
||||
progress_bar: ForwardRef('tqdm') = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_id: int,
|
||||
job,
|
||||
config: OrderedDict
|
||||
):
|
||||
super().__init__(process_id, job, config)
|
||||
self.input_path = self.get_conf('input_path', required=True)
|
||||
self.output_path = self.get_conf('output_path', required=True)
|
||||
self.replace_meta = self.get_conf('replace_meta', default=False)
|
||||
self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype)
|
||||
self.current_weight = self.get_conf('current_weight', required=True, as_type=float)
|
||||
self.target_weight = self.get_conf('target_weight', required=True, as_type=float)
|
||||
self.scale_target = self.get_conf('scale_target', default='up_down') # alpha or up_down
|
||||
self.is_xl = self.get_conf('is_xl', default=False, as_type=bool)
|
||||
self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool)
|
||||
|
||||
self.progress_bar = None
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
source_state_dict = load_file(self.input_path)
|
||||
source_meta = load_metadata_from_safetensors(self.input_path)
|
||||
|
||||
if self.replace_meta:
|
||||
self.meta.update(
|
||||
add_base_model_info_to_meta(
|
||||
self.meta,
|
||||
is_xl=self.is_xl,
|
||||
is_v2=self.is_v2,
|
||||
)
|
||||
)
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
else:
|
||||
save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False)
|
||||
|
||||
# save
|
||||
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
|
||||
|
||||
new_state_dict = OrderedDict()
|
||||
|
||||
for key in list(source_state_dict.keys()):
|
||||
v = source_state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32'))
|
||||
|
||||
# all loras have an alpha, up weight and down weight
|
||||
# - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha",
|
||||
# - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
|
||||
# - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",
|
||||
# we can rescale by adjusting the alpha or the up weights, or the up and down weights
|
||||
# I assume doing both up and down would be best all around, but I'm not sure
|
||||
# some locons also have mid weights, we will leave those alone for now, will work without them
|
||||
|
||||
# when adjusting alpha, it is used to calculate the multiplier in a lora module
|
||||
# - scale = alpha / lora_dim
|
||||
# - output = layer_out + lora_up_out * multiplier * scale
|
||||
total_module_scale = torch.tensor(self.current_weight / self.target_weight) \
|
||||
.to("cpu", dtype=get_torch_dtype('fp32'))
|
||||
num_modules_layers = 2 # up and down
|
||||
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
|
||||
.to("cpu", dtype=get_torch_dtype('fp32'))
|
||||
# only update alpha
|
||||
if self.scale_target == 'alpha' and key.endswith('.alpha'):
|
||||
v = v * total_module_scale
|
||||
if self.scale_target == 'up_down' and key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
|
||||
# would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN
|
||||
v = v * up_down_scale
|
||||
new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype))
|
||||
|
||||
save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
|
||||
save_file(new_state_dict, self.output_path, save_meta)
|
||||
|
||||
# cleanup incase there are other jobs
|
||||
del new_state_dict
|
||||
del source_state_dict
|
||||
del source_meta
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
print(f"Saved to {self.output_path}")
|
||||
@@ -46,8 +46,8 @@ class EncodedPromptPair:
|
||||
negative_target,
|
||||
negative_target_with_neutral,
|
||||
neutral,
|
||||
both_targets,
|
||||
empty_prompt,
|
||||
both_targets,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=1.0,
|
||||
weight=1.0
|
||||
@@ -123,23 +123,24 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
|
||||
|
||||
# read line by line from file
|
||||
with open(self.slider_config.prompt_file, 'r') as f:
|
||||
self.prompt_txt_list = f.readlines()
|
||||
# clean empty lines
|
||||
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
||||
if self.slider_config.prompt_file:
|
||||
with open(self.slider_config.prompt_file, 'r') as f:
|
||||
self.prompt_txt_list = f.readlines()
|
||||
# clean empty lines
|
||||
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
||||
|
||||
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
||||
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
||||
|
||||
|
||||
if not self.slider_config.prompt_tensors:
|
||||
# shuffle
|
||||
random.shuffle(self.prompt_txt_list)
|
||||
# trim to max steps
|
||||
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
|
||||
# trim list to our max steps
|
||||
|
||||
cache = PromptEmbedsCache()
|
||||
|
||||
if not self.slider_config.prompt_tensors:
|
||||
# shuffle
|
||||
random.shuffle(self.prompt_txt_list)
|
||||
# trim to max steps
|
||||
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
|
||||
# trim list to our max steps
|
||||
|
||||
|
||||
# get encoded latents for our prompts
|
||||
with torch.no_grad():
|
||||
if self.slider_config.prompt_tensors is not None:
|
||||
@@ -169,7 +170,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# encode empty_prompt
|
||||
cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)
|
||||
|
||||
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
|
||||
neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
|
||||
|
||||
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
|
||||
for target in self.slider_config.targets:
|
||||
prompt_list = [
|
||||
f"{target.target_class}", # target_class
|
||||
@@ -212,10 +215,15 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
save_file(state_dict, self.slider_config.prompt_tensors)
|
||||
|
||||
prompt_pairs = []
|
||||
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
|
||||
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
|
||||
for target in self.slider_config.targets:
|
||||
erase_negative = len(target.positive.strip()) == 0
|
||||
enhance_positive = len(target.negative.strip()) == 0
|
||||
|
||||
both = not erase_negative and not enhance_positive
|
||||
|
||||
if both or erase_negative:
|
||||
print("Encoding erase negative")
|
||||
prompt_pairs += [
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
@@ -234,6 +242,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
print("Encoding enhance positive")
|
||||
prompt_pairs += [
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
@@ -251,7 +260,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
# if both or enhance_positive:
|
||||
if both:
|
||||
print("Encoding erase positive (inverse)")
|
||||
prompt_pairs += [
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
@@ -269,7 +280,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
if both or erase_negative:
|
||||
# if both or erase_negative:
|
||||
if both:
|
||||
print("Encoding enhance negative (inverse)")
|
||||
prompt_pairs += [
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
@@ -341,10 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
|
||||
]
|
||||
|
||||
target_class = prompt_pair.target_class
|
||||
neutral = prompt_pair.neutral
|
||||
negative = prompt_pair.negative_target
|
||||
positive = prompt_pair.positive_target
|
||||
weight = prompt_pair.weight
|
||||
multiplier = prompt_pair.multiplier
|
||||
|
||||
|
||||
@@ -8,4 +8,5 @@ from .BaseMergeProcess import BaseMergeProcess
|
||||
from .TrainSliderProcess import TrainSliderProcess
|
||||
from .TrainSliderProcessOld import TrainSliderProcessOld
|
||||
from .TrainLoRAHack import TrainLoRAHack
|
||||
from .TrainSDRescaleProcess import TrainSDRescaleProcess
|
||||
from .TrainSDRescaleProcess import TrainSDRescaleProcess
|
||||
from .ModRescaleLoraProcess import ModRescaleLoraProcess
|
||||
|
||||
Reference in New Issue
Block a user