Added anchors to regulate the lora

This commit is contained in:
Jaret Burkett
2023-07-24 14:59:16 -06:00
parent 390192c6a1
commit 61dd818608
5 changed files with 180 additions and 31 deletions

View File

@@ -104,8 +104,18 @@ Just went in and out. It is much worse on smaller faces than shown here.
<img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">
---
## TODO
- [ ] Add proper regs on sliders
- [X] Add proper regs on sliders
- [ ] Add SDXL support (base model only for now)
- [ ] Add plain erasing
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
---
## Change Log
#### 2021-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights

View File

@@ -163,6 +163,25 @@ config:
# to a lower number like 0.1 so they dont outweigh the primary target
weight: 1.0
# anchors are prompts that wer try to hold on to while training the slider
# you want these to generate an image very similar to the target_class
# without directly overlapping it. For example, if you are training on a person smiling,
# you would use "a person with a face mask" as an anchor. It is a person, the image is the same
# regardless if they are smiling or not
anchors:
# only positive prompt for now
- prompt: "a woman"
neg_prompt: "animal"
# the multiplier applied to the LoRA when this is run.
# higher will give it more weight but also help keep the lora from collapsing
multiplier: 8.0
- prompt: "a man"
neg_prompt: "animal"
multiplier: 8.0
- prompt: "a person"
neg_prompt: "animal"
multiplier: 8.0
# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this

View File

@@ -2,6 +2,7 @@ import time
from collections import OrderedDict
import os
from leco.train_util import predict_noise
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -59,11 +60,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None
self.lr_scheduler = None
self.sd = None
self.sd: 'StableDiffusion' = None
# added later
self.network = None
self.scheduler = None
def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
@@ -118,7 +118,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
'multiplier': self.network.multiplier,
})
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
raw_prompt = self.sample_config.prompts[i]
neg = self.sample_config.neg
@@ -267,6 +267,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
# return loss
return 0.0
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffuse_some_steps(
self,
latents: torch.FloatTensor,
text_embeddings: torch.FloatTensor,
total_timesteps: int = 1000,
start_timesteps=0,
**kwargs,
):
for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
noise_pred = train_util.predict_noise(
self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
# return latents_steps
return latents
def run(self):
super().run()
@@ -368,7 +389,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
# todo handle dataloader here maybe, not sure
### HOOK ###
loss = self.hook_train_loop()
loss_dict = self.hook_train_loop()
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
self.progress_bar.set_postfix_str(prog_bar_string)
# don't do on first step
if self.step_num != self.start_step:
@@ -386,15 +421,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
# log to tensorboard
if self.writer is not None:
# get avg loss
self.writer.add_scalar(f"loss", loss, self.step_num)
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.refresh()

View File

@@ -29,6 +29,7 @@ def flush():
gc.collect()
class EncodedPromptPair:
def __init__(
self,
@@ -53,6 +54,18 @@ class EncodedPromptPair:
self.weight = weight
class EncodedAnchor:
def __init__(
self,
prompt,
neg_prompt,
multiplier=1.0
):
self.prompt = prompt
self.neg_prompt = neg_prompt
self.multiplier = multiplier
class TrainSliderProcess(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -61,9 +74,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
self.slider_config = SliderConfig(**self.get_conf('slider', {}))
self.prompt_cache = PromptEmbedsCache()
self.prompt_pairs: list[EncodedPromptPair] = []
self.anchor_pairs: list[EncodedAnchor] = []
def before_model_load(self):
pass
@@ -146,16 +159,39 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
]
# setup anchors
anchor_pairs = []
for anchor in self.slider_config.anchors:
# build the cache
for prompt in [
anchor.prompt,
anchor.neg_prompt # empty neutral
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt]
)
anchor_pairs += [
EncodedAnchor(
prompt=cache[anchor.prompt],
neg_prompt=cache[anchor.neg_prompt],
multiplier=anchor.multiplier
)
]
# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
self.sd.text_encoder.to("cpu")
self.prompt_cache = cache
self.prompt_pairs = prompt_pairs
self.anchor_pairs = anchor_pairs
flush()
# end hook_before_train_loop
def hook_train_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
@@ -202,10 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
assert self.network.is_active
# A little denoised one is returned
denoised_latents = train_util.diffusion(
unet,
noise_scheduler,
denoised_latents = self.diffuse_some_steps(
latents, # pass simple noise latents
train_util.concat_embeddings(
positive, # unconditional
@@ -261,7 +294,46 @@ class TrainSliderProcess(BaseSDTrainProcess):
guidance_scale=1,
).to("cpu", dtype=torch.float32)
anchor_loss = None
if len(self.anchor_pairs) > 0:
# get a random anchor pair
anchor: EncodedAnchor = self.anchor_pairs[
torch.randint(0, len(self.anchor_pairs), (1,)).item()
]
with torch.no_grad():
anchor_target_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
with self.network:
# anchor whatever weight prompt pair is using
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
self.network.multiplier = anchor.multiplier * pos_nem_mult
anchor_pred_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
self.network.multiplier = prompt_pair.multiplier
with self.network:
self.network.multiplier = prompt_pair.multiplier
target_latents = train_util.predict_noise(
unet,
noise_scheduler,
@@ -281,7 +353,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
positive_latents.requires_grad = False
neutral_latents.requires_grad = False
unconditional_latents.requires_grad = False
if len(self.anchor_pairs) > 0:
anchor_target_noise.requires_grad = False
anchor_loss = loss_function(
anchor_target_noise,
anchor_pred_noise,
)
erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
guidance_scale = 1.0
@@ -299,16 +376,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
offset_neutral,
) * weight
loss_float = loss.item()
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
loss_slide = loss.item()
self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}")
if anchor_loss is not None:
loss += anchor_loss
loss_float = loss.item()
loss = loss.to(self.device_torch)
loss.backward()
optimizer.step()
@@ -326,5 +401,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
# reset network
self.network.multiplier = 1.0
return loss_float
loss_dict = OrderedDict(
{'loss': loss_float},
)
if anchor_loss is not None:
loss_dict['sl_l'] = loss_slide
loss_dict['an_l'] = anchor_loss.item()
return loss_dict
# end hook_train_loop

View File

@@ -71,9 +71,19 @@ class SliderTargetConfig:
self.weight: float = kwargs.get('weight', 1.0)
class SliderConfigAnchors:
def __init__(self, **kwargs):
self.prompt = kwargs.get('prompt', '')
self.neg_prompt = kwargs.get('neg_prompt', '')
self.multiplier = kwargs.get('multiplier', 1.0)
class SliderConfig:
def __init__(self, **kwargs):
targets = kwargs.get('targets', [])
targets = [SliderTargetConfig(**target) for target in targets]
self.targets: List[SliderTargetConfig] = targets
anchors = kwargs.get('anchors', [])
anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
self.anchors: List[SliderConfigAnchors] = anchors
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])