Complete rework of how slider training works, heavily optimized. The entire algorithm can now run in one batch, using less than a quarter of the VRAM it used to take.

Jaret Burkett
2023-08-05 18:46:08 -06:00
parent 7e4e660663
commit 8c90fa86c6
10 changed files with 944 additions and 379 deletions

View File

@@ -170,18 +170,27 @@ Just went in and out. It is much worse on smaller faces than shown here.
## Change Log
#### 2023-08-05
- Huge memory rework and slider rework. Slider training is better than ever, with no more
RAM spikes. I also made it so all 4 parts of the slider algorithm run in one batch so they share gradient
accumulation; see the sketch below. This makes it much faster and more stable.
- Updated the example config to be something more practical and more in line with current methods. It is now
a detail slider and shows how to train one without a subject. 512x512 slider training for SD 1.5 should work on
a 6GB GPU now. Will test soon to verify.
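A minimal sketch of the "4 parts in one batch" idea from the entry above. The names `concat_prompt_pairs`, `split_prompt_pairs`, and the chunk count of 4 mirror the code later in this commit; `predict`, `target_for`, `loss_function`, and `optimizer` are placeholders:

    # stack the 4 slider passes (erase, enhance, and their inverses) into one batch
    pair_batch = concat_prompt_pairs(prompt_pair_batch)   # one batch holding 4 variants
    chunks = split_prompt_pairs(pair_batch, 4)            # split back apart for accumulation
    for chunk in chunks:
        loss = loss_function(predict(chunk), target_for(chunk)) * chunk.weight
        loss.backward()        # per-chunk backward frees each graph; gradients accumulate
    optimizer.step()           # one optimizer step covers the whole 4-part round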
#### 2021-10-20
- Windows support bug fixes
- Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
check out the example in the `extensions` folder. Read more about that above.
- Model Merging, provided via the example extension.
-#### 2021-08-03
+#### 2023-08-03
Another big refactor to make SD more modular.
Made batch image generation script
-#### 2021-08-01
+#### 2023-08-01
Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
at the moment, so hopefully there are not breaking changes.
@@ -199,7 +208,7 @@ encoders to the model as well as a few more entirely separate diffusion networks
training without every experimental new paper added to it. The KISS principle.
-#### 2021-07-30
+#### 2023-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights

View File

@@ -7,7 +7,7 @@ job: train
config:
# the name will be used to create a folder in the output folder
# it will also replace any [name] token in the rest of this config
-name: pet_slider_v1
+name: detail_slider_v1
# folder will be created with name above in folder below
# it can be relative to the project root or absolute
training_folder: "output/LoRA"
@@ -24,7 +24,7 @@ config:
type: "lierla" type: "lierla"
# rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
rank: 8 rank: 8
alpha: 1.0 # just leave it alpha: 4 # Do about half of rank
# training config # training config
train: train:
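For context on the new alpha default: in the LoRAModule implementation later in this diff, the LoRA output is scaled by alpha / rank, so an alpha at half of rank simply halves the effective LoRA strength. A minimal sketch of the relationship, using this config's values:

    rank = 8
    alpha = 4
    scale = alpha / rank  # 0.5 here; alpha == rank would mean no extra scaling
    # effective forward (see LoRAModule below):
    # out = org_forward(x) + lora_up(lora_down(x)) * multiplier * scale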
@@ -33,7 +33,7 @@ config:
# how many steps to train. More is not always better. I rarely go over 1000
steps: 500
# I have had good results with 4e-4 to 1e-4 at 500 steps
-lr: 1e-4
+lr: 2e-4
# enables gradient checkpoint, saves vram, leave it on
gradient_checkpointing: true
# train the unet. I recommend leaving this true
@@ -43,6 +43,7 @@ config:
# not the description of it (text encoder)
train_text_encoder: false
# just leave unless you know what you are doing
# also supports "dadaptation" but set lr to 1 if you use that,
# but it learns too fast and I don't recommend it
@@ -53,6 +54,7 @@ config:
# while training. Just leave it
max_denoising_steps: 40
# works great at 1. I do 1 even with my 4090.
# higher may not work right with newer single batch stacking code anyway
batch_size: 1
# bf16 works best if your GPU supports it (modern)
dtype: bf16 # fp32, bf16, fp16
@@ -69,12 +71,17 @@ config:
name_or_path: "runwayml/stable-diffusion-v1-5"
is_v2: false # for v2 models
is_v_pred: false # for v-prediction models (most v2 models)
# has some issues with the dual text encoder and the way we train sliders
# it works, but weights probably need to be higher to see the effect.
is_xl: false # for SDXL models
# saving config
save:
dtype: float16 # precision to save. I recommend float16
save_every: 50 # save every this many steps
# this will remove step saves older than this count
# allows you to save more often in case of a crash without filling up your drive
max_step_saves_to_keep: 2
# sampling config
sample:
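What max_step_saves_to_keep does, as an illustrative sketch only; the real pruning lives in this toolkit's save logic, and the file pattern here is an assumption:

    import glob
    import os

    def prune_step_saves(folder: str, keep: int = 2):
        # keep only the newest `keep` step checkpoints, delete the rest
        saves = sorted(glob.glob(os.path.join(folder, "*.safetensors")), key=os.path.getmtime)
        for path in saves[:-keep]:
            os.remove(path)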
@@ -92,21 +99,22 @@ config:
# --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
# slide are good tests. will inherit sample.network_multiplier if not set
# --n [string] # negative prompt, will inherit sample.neg if not set
# Only 75 tokens allowed currently
-prompts: # our example is an animal slider, neg: dog, pos: cat
-- "a golden retriever --m -5"
-- "a golden retriever --m -3"
-- "a golden retriever --m 3"
-- "a golden retriever --m 5"
-- "calico cat --m -5"
-- "calico cat --m -3"
-- "calico cat --m 3"
-- "calico cat --m 5"
-- "an elephant --m -5"
-- "an elephant --m -3"
-- "an elephant --m 3"
-- "an elephant --m 5"
+# I like to do a wide positive and negative spread so I can see a good range and stop
+# early if the network is breaking down
+prompts:
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
+- "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
+- "a golden retriever sitting on a leather couch --m -5"
+- "a golden retriever sitting on a leather couch --m -3"
+- "a golden retriever sitting on a leather couch --m 3"
+- "a golden retriever sitting on a leather couch --m 5"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
+- "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
# negative prompt used on all prompts above as default if they don't have one
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
# seed for sampling. 42 is the answer for everything
@@ -135,11 +143,16 @@ config:
# resolutions to train on. [ width, height ]. This is less important for sliders
# as we are not teaching the model anything it doesn't already know
# but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
# and [ 1024, 1024 ] for sd_xl
# you can do as many as you want here
resolutions:
- [ 512, 512 ]
# - [ 512, 768 ]
# - [ 768, 768 ]
# slider training uses 4 combined steps for a single round. This will do it in one gradient
# step. It is highly optimized and shouldn't take any more vram than doing it without,
# since we break down batches for gradient accumulation now. So just leave it on.
# (a sketch of the chunking idea follows this hunk)
batch_full_slide: true
# These are the concepts to train on. You can do as many as you want here,
# but they can conflict and outweigh each other. Other than experimenting, I recommend
# just doing one for good results
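The batch_full_slide chunking referenced above, as a rough sketch; `predict`, `target_for`, and `loss_function` are placeholders, and the real loop is TrainSliderProcess.hook_train_loop later in this commit:

    import torch

    def train_round(denoised_latents_4x, prompt_pair_chunks, optimizer):
        # one stacked batch holds all 4 slider passes; chunk it back apart
        latent_chunks = torch.chunk(denoised_latents_4x, 4, dim=0)
        losses = []
        for pair_chunk, latent_chunk in zip(prompt_pair_chunks, latent_chunks):
            loss = loss_function(predict(pair_chunk, latent_chunk), target_for(pair_chunk))
            loss.backward()            # frees each chunk's graph; gradients accumulate
            losses.append(loss.item())
        optimizer.step()               # a single step for the whole 4-part round
        optimizer.zero_grad()
        return sum(losses) / len(losses)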
@@ -150,7 +163,9 @@ config:
# a keyword necessarily but what the model understands the concept to represent.
# "person" will affect men, women, children, etc but will not affect cats, dogs, etc
# it is the model's base general understanding of the concept and everything it represents
-- target_class: "animal"
+# you can leave it blank to affect everything. In this example, we are adjusting
+# detail, so we leave it blank
+- target_class: ""
# positive is the prompt for the positive side of the slider.
# It is the concept that will be excited and amplified in the model when we slide the slider
# to the positive side and forgotten / inverted when we slide
@@ -158,33 +173,44 @@ config:
# the prompt. You want it to be the extreme of what you want to train on. For example,
# if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
# as the prompt. Not just "fat person"
-positive: "cat"
+# max 75 tokens for now
+positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
# negative is the prompt for the negative side of the slider and works the same as positive
# it does not necessarily work the same as a negative prompt when generating images
-negative: "dog"
+# these need to be polar opposites.
+# max 75 tokens for now
+negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
# the loss for this target is multiplied by this number.
# if you are doing more than one target it may be good to set less important ones
# to a lower number like 0.1 so they don't outweigh the primary target
weight: 1.0
-# anchors are prompts that we try to hold on to while training the slider
-# you want these to generate an image very similar to the target_class
-# without directly overlapping it. For example, if you are training on a person smiling,
-# you would use "a person with a face mask" as an anchor. It is a person, the image is the same
-# regardless if they are smiling or not
-anchors:
-# only positive prompt for now
-- prompt: "a woman"
-neg_prompt: "animal"
-# the multiplier applied to the LoRA when this is run.
-# higher will give it more weight but also help keep the lora from collapsing
-multiplier: 8.0
-- prompt: "a man"
-neg_prompt: "animal"
-multiplier: 8.0
-- prompt: "a person"
-neg_prompt: "animal"
-multiplier: 8.0
+# anchors are prompts that we will try to hold on to while training the slider
+# these are NOT necessary and can prevent the slider from converging if not done right
+# leave them off if you are having issues, but they can help lock the network
+# on certain concepts to help prevent catastrophic forgetting
+# you want these to generate an image that is not your target_class, but close to it
+# close is fine as long as it does not directly overlap it.
+# For example, if you are training on a person smiling,
+# you could use "a person with a face mask" as an anchor. It is a person, the image is the same
+# regardless if they are smiling or not, however, the closer the concept is to the target_class
+# the lower the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
+# for close concepts, you want to be closer to 0.1 or 0.2
+# these will slow down training. I am leaving them off for the demo
+# anchors:
+# - prompt: "a woman"
+# neg_prompt: "animal"
+# # the multiplier applied to the LoRA when this is run.
+# # higher will give it more weight but also help keep the lora from collapsing
+# multiplier: 1.0
+# - prompt: "a man"
+# neg_prompt: "animal"
+# multiplier: 1.0
+# - prompt: "a person"
+# neg_prompt: "animal"
+# multiplier: 1.0
# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
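The anchor mechanic described above, reduced to its core. A sketch only: `predict` stands in for the noise-prediction call, and the real loop is in TrainSliderProcess.hook_train_loop later in this commit:

    import torch
    import torch.nn.functional as F

    # reference prediction for the anchor prompt, with the LoRA inactive
    with torch.no_grad():
        ref = predict(anchor_prompt, anchor_neg_prompt)

    # with the LoRA active at the anchor multiplier, the prediction should not drift
    with network:
        network.multiplier = anchor_multiplier
        pred = predict(anchor_prompt, anchor_neg_prompt)
    anchor_loss = F.mse_loss(pred, ref)  # regularizes the slider toward the base model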

View File

@@ -3,6 +3,6 @@ from collections import OrderedDict
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
-v["version"] = "0.0.3"
+v["version"] = "0.0.4"
software_meta = v

View File

@@ -242,6 +242,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet.enable_xformers_memory_efficient_attention()
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# if isinstance(text_encoder, list):
# for te in text_encoder:
# te.enable_gradient_checkpointing()
# else:
# text_encoder.enable_gradient_checkpointing()
unet.to(self.device_torch, dtype=dtype)
unet.requires_grad_(False)
unet.eval()
@@ -281,6 +287,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
default_lr=self.train_config.lr
)
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")

View File

@@ -3,12 +3,14 @@
import random
from collections import OrderedDict
import os
-from typing import Optional
+from typing import Optional, Union
from safetensors.torch import save_file, load_file
import torch.utils.checkpoint as cp
from tqdm import tqdm
from toolkit.config_modules import SliderConfig
from toolkit.layers import CheckpointGradients
from toolkit.paths import REPOS_ROOT
import sys
@@ -16,88 +18,21 @@ from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
import gc
from toolkit import train_tools
from toolkit.prompt_utils import \
EncodedPromptPair, ACTION_TYPES_SLIDER, \
EncodedAnchor, concat_prompt_pairs, \
concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \
split_prompt_pairs
import torch
from .BaseSDTrainProcess import BaseSDTrainProcess
-class ACTION_TYPES_SLIDER:
-ERASE_NEGATIVE = 0
-ENHANCE_NEGATIVE = 1
def flush():
torch.cuda.empty_cache()
gc.collect()
-class EncodedPromptPair:
-def __init__(
-self,
-target_class,
-target_class_with_neutral,
-positive_target,
-positive_target_with_neutral,
-negative_target,
-negative_target_with_neutral,
-neutral,
-empty_prompt,
-both_targets,
-action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
-multiplier=1.0,
-weight=1.0
-):
-self.target_class = target_class
-self.target_class_with_neutral = target_class_with_neutral
-self.positive_target = positive_target
-self.positive_target_with_neutral = positive_target_with_neutral
-self.negative_target = negative_target
-self.negative_target_with_neutral = negative_target_with_neutral
-self.neutral = neutral
-self.empty_prompt = empty_prompt
-self.both_targets = both_targets
-self.multiplier = multiplier
-self.action: int = action
-self.weight = weight
-# simulate torch to for tensors
-def to(self, *args, **kwargs):
-self.target_class = self.target_class.to(*args, **kwargs)
-self.positive_target = self.positive_target.to(*args, **kwargs)
-self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
-self.negative_target = self.negative_target.to(*args, **kwargs)
-self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs)
-self.neutral = self.neutral.to(*args, **kwargs)
-self.empty_prompt = self.empty_prompt.to(*args, **kwargs)
-self.both_targets = self.both_targets.to(*args, **kwargs)
-return self
-class PromptEmbedsCache:
-prompts: dict[str, PromptEmbeds] = {}
-def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
-self.prompts[__name] = __value
-def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
-if __name in self.prompts:
-return self.prompts[__name]
-else:
-return None
-class EncodedAnchor:
-def __init__(
-self,
-prompt,
-neg_prompt,
-multiplier=1.0
-):
-self.prompt = prompt
-self.neg_prompt = neg_prompt
-self.multiplier = multiplier
class TrainSliderProcess(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -110,6 +45,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.prompt_cache = PromptEmbedsCache()
self.prompt_pairs: list[EncodedPromptPair] = []
self.anchor_pairs: list[EncodedAnchor] = []
# keep track of prompt chunk size
self.prompt_chunk_size = 1
def before_model_load(self):
pass
@@ -137,163 +74,57 @@ class TrainSliderProcess(BaseSDTrainProcess):
# get encoded latents for our prompts
with torch.no_grad():
-if self.slider_config.prompt_tensors is not None:
-# check to see if it exists
-if os.path.exists(self.slider_config.prompt_tensors):
-# load it.
-self.print(f"Loading prompt tensors from {self.slider_config.prompt_tensors}")
-prompt_tensors = load_file(self.slider_config.prompt_tensors, device='cpu')
-# add them to the cache
-for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
-if prompt_txt.startswith("te:"):
-prompt = prompt_txt[3:]
-# text_embeds
-text_embeds = prompt_tensor
-pooled_embeds = None
-# find pool embeds
-if f"pe:{prompt}" in prompt_tensors:
-pooled_embeds = prompt_tensors[f"pe:{prompt}"]
-# make it
-prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
-cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)
-if len(cache.prompts) == 0:
-print("Prompt tensors not found. Encoding prompts..")
-empty_prompt = ""
-# encode empty_prompt
-cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)
-neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
-for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
-for target in self.slider_config.targets:
-prompt_list = [
-f"{target.target_class}", # target_class
-f"{target.target_class} {neutral}", # target_class with neutral
-f"{target.positive}", # positive_target
-f"{target.positive} {neutral}", # positive_target with neutral
-f"{target.negative}", # negative_target
-f"{target.negative} {neutral}", # negative_target with neutral
-f"{neutral}", # neutral
-f"{target.positive} {target.negative}", # both targets
-f"{target.negative} {target.positive}", # both targets
-]
-for p in prompt_list:
-# build the cache
-if cache[p] is None:
-cache[p] = self.sd.encode_prompt(p).to(device="cpu", dtype=torch.float32)
-erase_negative = len(target.positive.strip()) == 0
-enhance_positive = len(target.negative.strip()) == 0
-both = not erase_negative and not enhance_positive
-if erase_negative and enhance_positive:
-raise ValueError("target must have at least one of positive or negative or both")
-# for slider we need to have an enhancer, an eraser, and then
-# an inverse with negative weights to balance the network
-# if we don't do this, we will get different contrast and focus.
-# we only perform actions of enhancing and erasing on the negative
-# todo work on way to do all of this in one shot
-if self.slider_config.prompt_tensors:
-print(f"Saving prompt tensors to {self.slider_config.prompt_tensors}")
-state_dict = {}
-for prompt_txt, prompt_embeds in cache.prompts.items():
-state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
-if prompt_embeds.pooled_embeds is not None:
-state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
-save_file(state_dict, self.slider_config.prompt_tensors)
-prompt_pairs = []
-for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
-for target in self.slider_config.targets:
-erase_negative = len(target.positive.strip()) == 0
-enhance_positive = len(target.negative.strip()) == 0
-both = not erase_negative and not enhance_positive
-if both or erase_negative:
-print("Encoding erase negative")
-prompt_pairs += [
-# erase standard
-EncodedPromptPair(
-target_class=cache[target.target_class],
-target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
-positive_target=cache[f"{target.positive}"],
-positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
-negative_target=cache[f"{target.negative}"],
-negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
-neutral=cache[neutral],
-action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
-multiplier=target.multiplier,
-both_targets=cache[f"{target.positive} {target.negative}"],
-empty_prompt=cache[""],
-weight=target.weight
-),
-]
-if both or enhance_positive:
-print("Encoding enhance positive")
-prompt_pairs += [
-# enhance standard, swap pos neg
-EncodedPromptPair(
-target_class=cache[target.target_class],
-target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
-positive_target=cache[f"{target.negative}"],
-positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
-negative_target=cache[f"{target.positive}"],
-negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
-neutral=cache[neutral],
-action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
-multiplier=target.multiplier,
-both_targets=cache[f"{target.positive} {target.negative}"],
-empty_prompt=cache[""],
-weight=target.weight
-),
-]
-# if both or enhance_positive:
-if both:
-print("Encoding erase positive (inverse)")
-prompt_pairs += [
-# erase inverted
-EncodedPromptPair(
-target_class=cache[target.target_class],
-target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
-positive_target=cache[f"{target.negative}"],
-positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
-negative_target=cache[f"{target.positive}"],
-negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
-neutral=cache[neutral],
-action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
-both_targets=cache[f"{target.positive} {target.negative}"],
-empty_prompt=cache[""],
-multiplier=target.multiplier * -1.0,
-weight=target.weight
-),
-]
-# if both or erase_negative:
-if both:
-print("Encoding enhance negative (inverse)")
-prompt_pairs += [
-# enhance inverted
-EncodedPromptPair(
-target_class=cache[target.target_class],
-target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
-positive_target=cache[f"{target.positive}"],
-positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
-negative_target=cache[f"{target.negative}"],
-negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
-both_targets=cache[f"{target.positive} {target.negative}"],
-neutral=cache[neutral],
-action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
-empty_prompt=cache[""],
-multiplier=target.multiplier * -1.0,
-weight=target.weight
-),
-]
+# list of neutrals. Can come from file or be empty
+neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
+# build the prompts to cache
+prompts_to_cache = []
+for neutral in neutral_list:
+for target in self.slider_config.targets:
+prompt_list = [
+f"{target.target_class}", # target_class
+f"{target.target_class} {neutral}", # target_class with neutral
+f"{target.positive}", # positive_target
+f"{target.positive} {neutral}", # positive_target with neutral
+f"{target.negative}", # negative_target
+f"{target.negative} {neutral}", # negative_target with neutral
+f"{neutral}", # neutral
+f"{target.positive} {target.negative}", # both targets
+f"{target.negative} {target.positive}", # both targets reverse
+]
+prompts_to_cache += prompt_list
+# remove duplicates
+prompts_to_cache = list(dict.fromkeys(prompts_to_cache))
+# encode them
+cache = encode_prompts_to_cache(
+prompt_list=prompts_to_cache,
+sd=self.sd,
+cache=cache,
+prompt_tensor_file=self.slider_config.prompt_tensors
+)
+prompt_pairs = []
+prompt_batches = []
+for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False):
+for target in self.slider_config.targets:
+prompt_pair_batch = build_prompt_pair_batch_from_cache(
+cache=cache,
+target=target,
+neutral=neutral,
+)
+if self.slider_config.batch_full_slide:
+# concat the prompt pairs
+# this allows us to run the entire 4 part process in one shot (for slider)
+self.prompt_chunk_size = 4
+concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu')
+prompt_pairs += [concat_prompt_pair_batch]
+else:
+self.prompt_chunk_size = 1
+# do them one at a time (probably not necessary after new optimizations)
+prompt_pairs += [x.to('cpu') for x in prompt_pair_batch]
# setup anchors # setup anchors
anchor_pairs = [] anchor_pairs = []
@@ -306,13 +137,26 @@ class TrainSliderProcess(BaseSDTrainProcess):
if cache[prompt] == None:
cache[prompt] = self.sd.encode_prompt(prompt)
anchor_batch = []
# we get the prompt pair multiplier from the first prompt pair
# since they are all the same. We need to match their network polarity
prompt_pair_multipliers = prompt_pairs[0].multiplier_list
for prompt_multiplier in prompt_pair_multipliers:
# match the network multiplier polarity
anchor_scalar = 1.0 if prompt_multiplier > 0 else -1.0
anchor_batch += [
EncodedAnchor(
prompt=cache[anchor.prompt],
neg_prompt=cache[anchor.neg_prompt],
multiplier=anchor.multiplier * anchor_scalar
)
]
-anchor_pairs += [
-EncodedAnchor(
-prompt=cache[anchor.prompt],
-neg_prompt=cache[anchor.neg_prompt],
-multiplier=anchor.multiplier
-)
-]
+anchor_pairs += [
+concat_anchors(anchor_batch).to('cpu')
+]
if len(anchor_pairs) > 0:
self.anchor_pairs = anchor_pairs
# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
@@ -324,17 +168,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.sd.text_encoder.to("cpu")
self.prompt_cache = cache
self.prompt_pairs = prompt_pairs
-self.anchor_pairs = anchor_pairs
+# self.anchor_pairs = anchor_pairs
flush()
# end hook_before_train_loop
def hook_train_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
-# get random multiplier between 1 and 3
-rand_weight = 1
-# rand_weight = torch.rand((1,)).item() * 2 + 1
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
@@ -346,11 +186,10 @@ class TrainSliderProcess(BaseSDTrainProcess):
height, width = self.slider_config.resolutions[
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
-weight = prompt_pair.weight
-multiplier = prompt_pair.multiplier
-unet = self.sd.unet
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
@@ -368,9 +207,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
guidance_scale=gs,
)
-# set network multiplier
-self.network.multiplier = multiplier * rand_weight
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
@@ -383,11 +219,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
1, self.train_config.max_denoising_steps, (1,)
).item()
# for a complete slider, the batch size is 4 to begin with now
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
# get noise
noise = self.sd.get_latent_noise(
pixel_height=height,
pixel_width=width,
-batch_size=self.train_config.batch_size,
+batch_size=true_batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
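Illustrative numbers for true_batch_size above: with batch_full_slide on, concat_prompt_pairs has already stacked 4 prompt variants, so the embeds carry a batch of 4 and the latent noise has to match.

    # a sketch with this config's values (batch_size: 1, full slide batching on)
    embeds_batch = 4   # prompt_pair.target_class.text_embeds.shape[0] after concat
    train_batch = 1    # self.train_config.batch_size
    true_batch_size = embeds_batch * train_batch  # noise shaped (4, C, H, W)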
@@ -397,7 +236,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
assert self.network.is_active
-self.network.multiplier = multiplier * rand_weight
+# pass the multiplier list to the network
+self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
@@ -410,19 +250,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
guidance_scale=3,
)
# split the latents into our prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
noise_scheduler.set_timesteps(1000)
current_timestep = noise_scheduler.timesteps[
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
]
# flush() # 4.2GB to 3GB on 512x512
# 4.20 GB RAM for 512x512
positive_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
prompt_pair.negative_target, # positive prompt
1,
current_timestep,
denoised_latents
-).to("cpu", dtype=torch.float32)
+)
+positive_latents.requires_grad = False
+positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
neutral_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -430,7 +278,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
1,
current_timestep,
denoised_latents
-).to("cpu", dtype=torch.float32)
+)
+neutral_latents.requires_grad = False
+neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
unconditional_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -438,87 +288,142 @@ class TrainSliderProcess(BaseSDTrainProcess):
1,
current_timestep,
denoised_latents
-).to("cpu", dtype=torch.float32)
-anchor_loss = None
-if len(self.anchor_pairs) > 0:
-# get a random anchor pair
-anchor: EncodedAnchor = self.anchor_pairs[
-torch.randint(0, len(self.anchor_pairs), (1,)).item()
-]
-with torch.no_grad():
-anchor_target_noise = get_noise_pred(
-anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
-).to("cpu", dtype=torch.float32)
-with self.network:
-# anchor whatever weight prompt pair is using
-pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
-self.network.multiplier = anchor.multiplier * pos_nem_mult * rand_weight
-anchor_pred_noise = get_noise_pred(
-anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
-).to("cpu", dtype=torch.float32)
-self.network.multiplier = prompt_pair.multiplier * rand_weight
-with self.network:
-self.network.multiplier = prompt_pair.multiplier * rand_weight
-target_latents = get_noise_pred(
-prompt_pair.positive_target,
-prompt_pair.target_class,
-1,
-current_timestep,
-denoised_latents
-).to("cpu", dtype=torch.float32)
-# if self.logging_config.verbose:
-# self.print("target_latents:", target_latents[0, 0, :5, :5])
-positive_latents.requires_grad = False
-neutral_latents.requires_grad = False
-unconditional_latents.requires_grad = False
-if len(self.anchor_pairs) > 0:
-anchor_target_noise.requires_grad = False
-anchor_loss = loss_function(
-anchor_target_noise,
-anchor_pred_noise,
-)
-erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
-guidance_scale = 1.0
-offset = guidance_scale * (positive_latents - unconditional_latents)
-offset_neutral = neutral_latents
-if erase:
-offset_neutral -= offset
-else:
-# enhance
-offset_neutral += offset
-loss = loss_function(
-target_latents,
-offset_neutral,
-) * weight
-loss_slide = loss.item()
-if anchor_loss is not None:
-loss += anchor_loss
-loss_float = loss.item()
-loss = loss.to(self.device_torch)
-loss.backward()
+)
+unconditional_latents.requires_grad = False
+unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
+flush() # 4.2GB to 3GB on 512x512
+# 4.20 GB RAM for 512x512
+anchor_loss_float = None
+if len(self.anchor_pairs) > 0:
+with torch.no_grad():
+# get a random anchor pair
+anchor: EncodedAnchor = self.anchor_pairs[
+torch.randint(0, len(self.anchor_pairs), (1,)).item()
+]
+anchor.to(self.device_torch, dtype=dtype)
+# first we get the target prediction without network active
+anchor_target_noise = get_noise_pred(
+anchor.neg_prompt, anchor.prompt, 1, current_timestep, denoised_latents
+# ).to("cpu", dtype=torch.float32)
+).requires_grad_(False)
+# to save vram, we will run these through separately while tracking grads
+# otherwise it consumes a ton of vram and this isn't our speed bottleneck
+anchor_chunks = split_anchors(anchor, self.prompt_chunk_size)
+anchor_target_noise_chunks = torch.chunk(anchor_target_noise, self.prompt_chunk_size, dim=0)
+assert len(anchor_chunks) == len(denoised_latent_chunks)
+# 4.32 GB RAM for 512x512
+with self.network:
+assert self.network.is_active
+anchor_float_losses = []
+for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
+anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
+):
+self.network.multiplier = anchor_chunk.multiplier_list
+anchor_pred_noise = get_noise_pred(
+anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
+)
+# 9.42 GB RAM for 512x512 -> 4.20 GB RAM for 512x512 with new grad_checkpointing
+anchor_loss = loss_function(
+anchor_target_noise_chunk,
+anchor_pred_noise,
+)
+anchor_float_losses.append(anchor_loss.item())
+# compute anchor loss gradients
+# we will accumulate them later
+# this saves a ton of memory doing them separately
+anchor_loss.backward()
+del anchor_pred_noise
+del anchor_target_noise_chunk
+del anchor_loss
+flush()
+anchor_loss_float = sum(anchor_float_losses) / len(anchor_float_losses)
+del anchor_chunks
+del anchor_target_noise_chunks
+del anchor_target_noise
+# move anchor back to cpu
+anchor.to("cpu")
+flush()
+prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
+assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
+# 3.28 GB RAM for 512x512
+with self.network:
+assert self.network.is_active
+loss_list = []
+for prompt_pair_chunk, \
+denoised_latent_chunk, \
+positive_latents_chunk, \
+neutral_latents_chunk, \
+unconditional_latents_chunk \
+in zip(
+prompt_pair_chunks,
+denoised_latent_chunks,
+positive_latents_chunks,
+neutral_latents_chunks,
+unconditional_latents_chunks,
+):
+self.network.multiplier = prompt_pair_chunk.multiplier_list
+target_latents = get_noise_pred(
+prompt_pair_chunk.positive_target,
+prompt_pair_chunk.target_class,
+1,
+current_timestep,
+denoised_latent_chunk
+)
+guidance_scale = 1.0
+offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk)
+# make offset multiplier based on actions
+offset_multiplier_list = []
+for action in prompt_pair_chunk.action_list:
+if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE:
+offset_multiplier_list += [-1.0]
+elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE:
+offset_multiplier_list += [1.0]
+offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype)
+# make offset multiplier match rank of offset
+offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1)
+offset *= offset_multiplier
+offset_neutral = neutral_latents_chunk
+# offsets are already adjusted on a per-batch basis
+offset_neutral += offset
+# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
+loss = loss_function(
+target_latents,
+offset_neutral,
+) * prompt_pair_chunk.weight
+loss.backward()
+loss_list.append(loss.item())
+del target_latents
+del offset_neutral
+del loss
+flush()
optimizer.step()
lr_scheduler.step()
+loss_float = sum(loss_list) / len(loss_list)
+if anchor_loss_float is not None:
+loss_float += anchor_loss_float
del (
positive_latents,
neutral_latents,
unconditional_latents,
-target_latents,
-latents,
+latents
)
# move back to cpu
prompt_pair.to("cpu")
@@ -530,9 +435,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
loss_dict = OrderedDict(
{'loss': loss_float},
)
-if anchor_loss is not None:
-loss_dict['sl_l'] = loss_slide
-loss_dict['an_l'] = anchor_loss.item()
+if anchor_loss_float is not None:
+loss_dict['sl_l'] = loss_float
+loss_dict['an_l'] = anchor_loss_float
return loss_dict
# end hook_train_loop
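The slider objective in the loop above, restated compactly. A sketch only: tensor names are illustrative, `pred_with_lora` stands in for the network-active get_noise_pred call, and mse_loss stands in for the trainer's loss_function:

    import torch
    import torch.nn.functional as F

    def slider_target(neutral, positive, unconditional, actions, g=1.0):
        # actions: -1.0 for ERASE_NEGATIVE, +1.0 for ENHANCE_NEGATIVE, one per batch item
        sign = torch.tensor(actions).view(-1, 1, 1, 1).to(neutral)
        return neutral + sign * g * (positive - unconditional)

    loss = F.mse_loss(pred_with_lora, slider_target(neutral, positive, unconditional, actions))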

View File

@@ -108,6 +108,7 @@ class SliderConfig:
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
class GenerateImageConfig:

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
from torch.utils.checkpoint import checkpoint
class ReductionKernel(nn.Module): class ReductionKernel(nn.Module):
@@ -29,3 +30,15 @@ class ReductionKernel(nn.Module):
def forward(self, x):
return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1)
class CheckpointGradients(nn.Module):
def __init__(self, is_gradient_checkpointing=True):
super(CheckpointGradients, self).__init__()
self.is_gradient_checkpointing = is_gradient_checkpointing
def forward(self, module, *args):
if self.is_gradient_checkpointing:
# recompute the module's activations during backward instead of storing them
return checkpoint(module, *args)
else:
return module(*args)
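A quick usage sketch for CheckpointGradients above; the wrapped module and input are hypothetical:

    import torch
    import torch.nn as nn

    layer = nn.Linear(64, 64)
    cp_layer = CheckpointGradients(is_gradient_checkpointing=True)
    x = torch.randn(8, 64, requires_grad=True)
    y = cp_layer(layer, x)  # activations recomputed in backward, trading compute for VRAM
    y.sum().backward()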

View File

@@ -1,4 +1,6 @@
import math
import os
import re
import sys
from typing import List, Optional, Dict, Type, Union
@@ -9,7 +11,170 @@ from .paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
-from networks.lora import LoRANetwork, LoRAModule, get_block_index
+from networks.lora import LoRANetwork, get_block_index
from torch.utils.checkpoint import checkpoint
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier: Union[float, List[float]] = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.is_checkpointing = False
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
# this allows us to set different multipliers on a per item in a batch basis
# allowing us to run positive and negative weights in the same batch
# really only useful for slider training for now
def get_multiplier(self, lora_up):
batch_size = lora_up.size(0)
# batch will have all negative prompts first and positive prompts second
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
# if the batch is larger than our multiplier list, it is likely a batch size increase, so we need to
# interleave the multipliers
if isinstance(self.multiplier, list):
if len(self.multiplier) == 1:
# single item, just return it
return self.multiplier[0]
else:
# we have a list of multipliers, so we need to get the multiplier for this batch
multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
# should be 1 if the total batch size was 1
num_interleaves = (batch_size // 2) // len(self.multiplier)
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
# match lora_up rank
if len(lora_up.size()) == 2:
multiplier_tensor = multiplier_tensor.view(-1, 1)
elif len(lora_up.size()) == 3:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
elif len(lora_up.size()) == 4:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
return multiplier_tensor
else:
return self.multiplier
def _call_forward(self, x):
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return 0.0 # added to original forward
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# this could also be computed from the mask, but rank_dropout is used hoping for an augmentation-like effect
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
multiplier = self.get_multiplier(lx)
return lx * multiplier * scale
def create_custom_forward(self):
def custom_forward(*inputs):
return self._call_forward(*inputs)
return custom_forward
def forward(self, x):
org_forwarded = self.org_forward(x)
# TODO this just loses the grad. Not sure why. Probably why no one else is doing it either
# if torch.is_grad_enabled() and self.is_checkpointing and self.training:
# lora_output = checkpoint(
# self.create_custom_forward(),
# x,
# )
# else:
# lora_output = self._call_forward(x)
lora_output = self._call_forward(x)
return org_forwarded + lora_output
def enable_gradient_checkpointing(self):
self.is_checkpointing = True
def disable_gradient_checkpointing(self):
self.is_checkpointing = False
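What the per-batch-item multiplier in get_multiplier above enables, as a rough usage sketch (shapes and values are illustrative): one forward pass where negative-side and positive-side prompts in the same batch get different LoRA weights.

    # hypothetical: a batch of 4 covering one prompt pair, negatives first then positives
    lora.multiplier = [-1.0, 1.0]      # per prompt-pair item; repeated across neg/pos halves
    x = torch.randn(4, 77, 768)        # (batch, tokens, dim), illustrative shape
    y = lora(x)                        # each batch item is scaled by its own multiplier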
class LoRASpecialNetwork(LoRANetwork):
@@ -70,6 +235,7 @@ class LoRASpecialNetwork(LoRANetwork):
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.is_checkpointing = False
if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -236,14 +402,11 @@ class LoRASpecialNetwork(LoRANetwork):
torch.save(state_dict, file)
@property
-def multiplier(self):
+def multiplier(self) -> Union[float, List[float]]:
return self._multiplier
@multiplier.setter
-def multiplier(self, value):
-# only update if changed
-if self._multiplier == value:
-return
+def multiplier(self, value: Union[float, List[float]]):
self._multiplier = value
self._update_lora_multiplier()
@@ -264,6 +427,8 @@ class LoRASpecialNetwork(LoRANetwork):
for lora in self.text_encoder_loras:
lora.multiplier = 0
# called when the context manager is entered
# ie: with network:
def __enter__(self):
self.is_active = True
self._update_lora_multiplier()
@@ -281,3 +446,29 @@ class LoRASpecialNetwork(LoRANetwork):
loras += self.text_encoder_loras
for lora in loras:
lora.to(device, dtype)
def _update_checkpointing(self):
if self.is_checkpointing:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.enable_gradient_checkpointing()
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.enable_gradient_checkpointing()
else:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.disable_gradient_checkpointing()
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.disable_gradient_checkpointing()
def enable_gradient_checkpointing(self):
# not supported
self.is_checkpointing = True
self._update_checkpointing()
def disable_gradient_checkpointing(self):
# not supported
self.is_checkpointing = False
self._update_checkpointing()

toolkit/prompt_utils.py (new file, 387 lines)
View File

@@ -0,0 +1,387 @@
import os
from typing import Optional, TYPE_CHECKING, List
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
class ACTION_TYPES_SLIDER:
ERASE_NEGATIVE = 0
ENHANCE_NEGATIVE = 1
class EncodedPromptPair:
def __init__(
self,
target_class,
target_class_with_neutral,
positive_target,
positive_target_with_neutral,
negative_target,
negative_target_with_neutral,
neutral,
empty_prompt,
both_targets,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
action_list=None,
multiplier=1.0,
multiplier_list=None,
weight=1.0
):
self.target_class: PromptEmbeds = target_class
self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
self.positive_target: PromptEmbeds = positive_target
self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral
self.negative_target: PromptEmbeds = negative_target
self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral
self.neutral: PromptEmbeds = neutral
self.empty_prompt: PromptEmbeds = empty_prompt
self.both_targets: PromptEmbeds = both_targets
self.multiplier: float = multiplier
if multiplier_list is not None:
self.multiplier_list: list[float] = multiplier_list
else:
self.multiplier_list: list[float] = [multiplier]
self.action: int = action
if action_list is not None:
self.action_list: list[int] = action_list
else:
self.action_list: list[int] = [action]
self.weight: float = weight
# simulate torch to for tensors
def to(self, *args, **kwargs):
self.target_class = self.target_class.to(*args, **kwargs)
self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
self.positive_target = self.positive_target.to(*args, **kwargs)
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
self.negative_target = self.negative_target.to(*args, **kwargs)
self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs)
self.neutral = self.neutral.to(*args, **kwargs)
self.empty_prompt = self.empty_prompt.to(*args, **kwargs)
self.both_targets = self.both_targets.to(*args, **kwargs)
return self
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
pooled_embeds = None
if prompt_embeds[0].pooled_embeds is not None:
pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
return PromptEmbeds([text_embeds, pooled_embeds])
def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
weight = prompt_pairs[0].weight
target_class = concat_prompt_embeds([p.target_class for p in prompt_pairs])
target_class_with_neutral = concat_prompt_embeds([p.target_class_with_neutral for p in prompt_pairs])
positive_target = concat_prompt_embeds([p.positive_target for p in prompt_pairs])
positive_target_with_neutral = concat_prompt_embeds([p.positive_target_with_neutral for p in prompt_pairs])
negative_target = concat_prompt_embeds([p.negative_target for p in prompt_pairs])
negative_target_with_neutral = concat_prompt_embeds([p.negative_target_with_neutral for p in prompt_pairs])
neutral = concat_prompt_embeds([p.neutral for p in prompt_pairs])
empty_prompt = concat_prompt_embeds([p.empty_prompt for p in prompt_pairs])
both_targets = concat_prompt_embeds([p.both_targets for p in prompt_pairs])
# combine all the lists
action_list = []
multiplier_list = []
weight_list = []
for p in prompt_pairs:
action_list += p.action_list
multiplier_list += p.multiplier_list
return EncodedPromptPair(
target_class=target_class,
target_class_with_neutral=target_class_with_neutral,
positive_target=positive_target,
positive_target_with_neutral=positive_target_with_neutral,
negative_target=negative_target,
negative_target_with_neutral=negative_target_with_neutral,
neutral=neutral,
empty_prompt=empty_prompt,
both_targets=both_targets,
action_list=action_list,
multiplier_list=multiplier_list,
weight=weight
)
def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]:
if num_parts is None:
# use batch size
num_parts = concatenated.text_embeds.shape[0]
text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
if concatenated.pooled_embeds is not None:
pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0)
else:
pooled_embeds_splits = [None] * num_parts
prompt_embeds_list = [
PromptEmbeds([text, pooled])
for text, pooled in zip(text_embeds_splits, pooled_embeds_splits)
]
return prompt_embeds_list
def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]:
target_class_splits = split_prompt_embeds(concatenated.target_class, num_embeds)
target_class_with_neutral_splits = split_prompt_embeds(concatenated.target_class_with_neutral, num_embeds)
positive_target_splits = split_prompt_embeds(concatenated.positive_target, num_embeds)
positive_target_with_neutral_splits = split_prompt_embeds(concatenated.positive_target_with_neutral, num_embeds)
negative_target_splits = split_prompt_embeds(concatenated.negative_target, num_embeds)
negative_target_with_neutral_splits = split_prompt_embeds(concatenated.negative_target_with_neutral, num_embeds)
neutral_splits = split_prompt_embeds(concatenated.neutral, num_embeds)
empty_prompt_splits = split_prompt_embeds(concatenated.empty_prompt, num_embeds)
both_targets_splits = split_prompt_embeds(concatenated.both_targets, num_embeds)
prompt_pairs = []
for i in range(len(target_class_splits)):
action_list_split = concatenated.action_list[i::len(target_class_splits)]
multiplier_list_split = concatenated.multiplier_list[i::len(target_class_splits)]
prompt_pair = EncodedPromptPair(
target_class=target_class_splits[i],
target_class_with_neutral=target_class_with_neutral_splits[i],
positive_target=positive_target_splits[i],
positive_target_with_neutral=positive_target_with_neutral_splits[i],
negative_target=negative_target_splits[i],
negative_target_with_neutral=negative_target_with_neutral_splits[i],
neutral=neutral_splits[i],
empty_prompt=empty_prompt_splits[i],
both_targets=both_targets_splits[i],
action_list=action_list_split,
multiplier_list=multiplier_list_split,
weight=concatenated.weight
)
prompt_pairs.append(prompt_pair)
return prompt_pairs
class PromptEmbedsCache:
    def __init__(self) -> None:
        # per-instance store; a class-level dict default would be shared by
        # every cache instance
        self.prompts: dict[str, PromptEmbeds] = {}
def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
self.prompts[__name] = __value
def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
if __name in self.prompts:
return self.prompts[__name]
else:
return None
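# Lookups return None instead of raising KeyError, so callers can check
# `cache[p] is None` and encode on a miss, as the encode loop below does.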
class EncodedAnchor:
def __init__(
self,
prompt,
neg_prompt,
multiplier=1.0,
multiplier_list=None
):
self.prompt = prompt
self.neg_prompt = neg_prompt
self.multiplier = multiplier
if multiplier_list is not None:
self.multiplier_list: list[float] = multiplier_list
else:
self.multiplier_list: list[float] = [multiplier]
def to(self, *args, **kwargs):
self.prompt = self.prompt.to(*args, **kwargs)
self.neg_prompt = self.neg_prompt.to(*args, **kwargs)
return self
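# Anchors (regularizer prompts) get the same concat/split treatment as the
# prompt pairs above so they can be batched through the model the same way.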
def concat_anchors(anchors: list[EncodedAnchor]):
prompt = concat_prompt_embeds([a.prompt for a in anchors])
neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors])
return EncodedAnchor(
prompt=prompt,
neg_prompt=neg_prompt,
multiplier_list=[a.multiplier for a in anchors]
)
def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]:
prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors)
neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors)
multiplier_list_splits = torch.chunk(torch.tensor(concatenated.multiplier_list), num_anchors)
anchors = []
for prompt, neg_prompt, multiplier in zip(prompt_splits, neg_prompt_splits, multiplier_list_splits):
        anchor = EncodedAnchor(
            prompt=prompt,
            neg_prompt=neg_prompt,
            # pass the chunked values as multiplier_list; handing a list to the
            # scalar multiplier arg would get wrapped in another list
            multiplier_list=multiplier.tolist()
        )
anchors.append(anchor)
return anchors
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
@torch.no_grad()
def encode_prompts_to_cache(
prompt_list: list[str],
sd: "StableDiffusion",
cache: Optional[PromptEmbedsCache] = None,
prompt_tensor_file: Optional[str] = None,
) -> PromptEmbedsCache:
# TODO: add support for larger prompts
if cache is None:
cache = PromptEmbedsCache()
if prompt_tensor_file is not None:
# check to see if it exists
if os.path.exists(prompt_tensor_file):
# load it.
print(f"Loading prompt tensors from {prompt_tensor_file}")
prompt_tensors = load_file(prompt_tensor_file, device='cpu')
# add them to the cache
for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
if prompt_txt.startswith("te:"):
prompt = prompt_txt[3:]
# text_embeds
text_embeds = prompt_tensor
pooled_embeds = None
                    # find matching pooled embeds
if f"pe:{prompt}" in prompt_tensors:
pooled_embeds = prompt_tensors[f"pe:{prompt}"]
# make it
prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)
if len(cache.prompts) == 0:
print("Prompt tensors not found. Encoding prompts..")
empty_prompt = ""
# encode empty_prompt
cache[empty_prompt] = sd.encode_prompt(empty_prompt)
for p in tqdm(prompt_list, desc="Encoding prompts", leave=False):
# build the cache
if cache[p] is None:
cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16)
# should we shard? It can get large
if prompt_tensor_file:
print(f"Saving prompt tensors to {prompt_tensor_file}")
state_dict = {}
for prompt_txt, prompt_embeds in cache.prompts.items():
state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to(
"cpu", dtype=get_torch_dtype('fp16')
)
if prompt_embeds.pooled_embeds is not None:
state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to(
"cpu",
dtype=get_torch_dtype('fp16')
)
save_file(state_dict, prompt_tensor_file)
return cache
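# The saved file is keyed by prompt text: "te:<prompt>" stores the text
# embeds and "pe:<prompt>" the pooled embeds, matching the loader above.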
if TYPE_CHECKING:
from toolkit.config_modules import SliderTargetConfig
@torch.no_grad()
def build_prompt_pair_batch_from_cache(
cache: PromptEmbedsCache,
target: 'SliderTargetConfig',
neutral: Optional[str] = '',
) -> list[EncodedPromptPair]:
erase_negative = len(target.positive.strip()) == 0
enhance_positive = len(target.negative.strip()) == 0
both = not erase_negative and not enhance_positive
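    # an empty positive means erase-only, an empty negative means enhance-only,
    # otherwise both directions are built; each enabled direction also gets an
    # inverted twin (multiplier * -1.0) for the opposite end of the slider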
prompt_pair_batch = []
if both or erase_negative:
print("Encoding erase negative")
prompt_pair_batch += [
# erase standard
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.positive}"],
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
negative_target=cache[f"{target.negative}"],
negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
weight=target.weight
),
]
if both or enhance_positive:
print("Encoding enhance positive")
prompt_pair_batch += [
# enhance standard, swap pos neg
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.negative}"],
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
negative_target=cache[f"{target.positive}"],
negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
weight=target.weight
),
]
if both or enhance_positive:
print("Encoding erase positive (inverse)")
prompt_pair_batch += [
# erase inverted
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.negative}"],
positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
negative_target=cache[f"{target.positive}"],
negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]
if both or erase_negative:
print("Encoding enhance negative (inverse)")
prompt_pair_batch += [
# enhance inverted
EncodedPromptPair(
target_class=cache[target.target_class],
target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
positive_target=cache[f"{target.positive}"],
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
negative_target=cache[f"{target.negative}"],
negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
both_targets=cache[f"{target.positive} {target.negative}"],
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
empty_prompt=cache[""],
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]
return prompt_pair_batch
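The concat/split helpers above are the heart of the one-batch rework. A minimal sketch of the pattern in plain torch (made-up shapes, not the toolkit API): stack the per-pass embeddings along the batch dimension, run one forward pass, then chunk the output back apart.

import torch

# four hypothetical prompt embeddings, one per slider pass: (1, 77, 768) each
embeds = [torch.randn(1, 77, 768) for _ in range(4)]

# concat along the batch dim, as concat_prompt_pairs does -> (4, 77, 768)
batched = torch.cat(embeds, dim=0)

# a single forward pass; a placeholder op stands in for the UNet call here
out = batched * 2.0

# chunk back into per-pass results, as split_prompt_pairs does
parts = torch.chunk(out, 4, dim=0)
assert all(p.shape == (1, 77, 768) for p in parts)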

View File

@@ -1,6 +1,6 @@
 import gc
 import typing
-from typing import Union, OrderedDict, List
+from typing import Union, OrderedDict, List, Tuple
 import sys
 import os
@@ -50,10 +50,10 @@ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
 class PromptEmbeds:
-    text_embeds: torch.FloatTensor
+    text_embeds: torch.Tensor
-    pooled_embeds: Union[torch.FloatTensor, None]
+    pooled_embeds: Union[torch.Tensor, None]
-    def __init__(self, args) -> None:
+    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
         if isinstance(args, list) or isinstance(args, tuple):
             # xl
             self.text_embeds = args[0]
@@ -139,12 +139,23 @@ class StableDiffusion:
             pipln = self.custom_pipeline
         else:
             pipln = CustomStableDiffusionXLPipeline
-        pipe = pipln.from_single_file(
-            self.model_config.name_or_path,
-            dtype=dtype,
-            scheduler_type='ddpm',
-            device=self.device_torch,
-        ).to(self.device_torch)
+        # see if path exists
+        if not os.path.exists(self.model_config.name_or_path):
+            # try to load with default diffusers
+            pipe = pipln.from_pretrained(
+                self.model_config.name_or_path,
+                dtype=dtype,
+                scheduler_type='ddpm',
+                device=self.device_torch,
+            ).to(self.device_torch)
+        else:
+            pipe = pipln.from_single_file(
+                self.model_config.name_or_path,
+                dtype=dtype,
+                scheduler_type='ddpm',
+                device=self.device_torch,
+            ).to(self.device_torch)
         text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
         tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
@@ -158,14 +169,27 @@ class StableDiffusion:
             pipln = self.custom_pipeline
         else:
             pipln = CustomStableDiffusionPipeline
-        pipe = pipln.from_single_file(
-            self.model_config.name_or_path,
-            dtype=dtype,
-            scheduler_type='dpm',
-            device=self.device_torch,
-            load_safety_checker=False,
-            requires_safety_checker=False,
-        ).to(self.device_torch)
+        # see if path exists
+        if not os.path.exists(self.model_config.name_or_path):
+            # try to load with default diffusers
+            pipe = pipln.from_pretrained(
+                self.model_config.name_or_path,
+                dtype=dtype,
+                scheduler_type='dpm',
+                device=self.device_torch,
+                load_safety_checker=False,
+                requires_safety_checker=False,
+            ).to(self.device_torch)
+        else:
+            pipe = pipln.from_single_file(
+                self.model_config.name_or_path,
+                dtype=dtype,
+                scheduler_type='dpm',
+                device=self.device_torch,
+                load_safety_checker=False,
+                requires_safety_checker=False,
+            ).to(self.device_torch)
         pipe.register_to_config(requires_safety_checker=False)
         text_encoder = pipe.text_encoder
         text_encoder.to(self.device_torch, dtype=dtype)
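Both hunks make the same change: if name_or_path does not exist on disk, treat it as a diffusers hub repo id instead of a single checkpoint file. A minimal sketch with the stock diffusers pipeline (hypothetical model id; the real code uses the custom pipeline classes):

import os
from diffusers import StableDiffusionPipeline

name_or_path = "runwayml/stable-diffusion-v1-5"  # hypothetical value

if os.path.exists(name_or_path):
    # local .ckpt/.safetensors checkpoint
    pipe = StableDiffusionPipeline.from_single_file(name_or_path)
else:
    # anything else is assumed to be a hub repo id
    pipe = StableDiffusionPipeline.from_pretrained(name_or_path)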