Added anchors to regularize the LoRA
README.md
@@ -104,8 +104,18 @@ Just went in and out. It is much worse on smaller faces than shown here.
<img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">
---
## TODO
-- [ ] Add proper regs on sliders
+- [X] Add proper regs on sliders
- [ ] Add SDXL support (base model only for now)
- [ ] Add plain erasing
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)

+---
+
+## Change Log
+
+#### 2023-07-30
+
+Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a regularizer. You can set the network multiplier to force spread consistency at high weights.
@@ -163,6 +163,25 @@ config:
      # to a lower number like 0.1 so they don't outweigh the primary target
      weight: 1.0

+     # anchors are prompts that we try to hold on to while training the slider.
+     # you want these to generate an image very similar to the target_class
+     # without directly overlapping it. For example, if you are training on a person smiling,
+     # you would use "a person with a face mask" as an anchor. It is a person, and the image is
+     # the same regardless of whether they are smiling or not
+     anchors:
+       # only positive prompts for now
+       - prompt: "a woman"
+         neg_prompt: "animal"
+         # the multiplier applied to the LoRA when this is run.
+         # higher will give it more weight but also helps keep the LoRA from collapsing
+         multiplier: 8.0
+       - prompt: "a man"
+         neg_prompt: "animal"
+         multiplier: 8.0
+       - prompt: "a person"
+         neg_prompt: "animal"
+         multiplier: 8.0

  # You can put any information you want here, and it will be saved in the model.
  # The below is an example, but you can put your grocery list in it if you want.
  # It is saved in the model so be aware of that. The software will include this
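For orientation, the sketch below shows where the new `anchors` block sits relative to `targets` inside the `slider:` section. This is an assumption pieced together from the README's example config and the trainer's `get_conf('slider', {})` call; the surrounding keys (`config`, `process`, `target_class`) are illustrative and not part of this commit.

```yaml
config:
  process:
    - type: 'slider'
      slider:
        resolutions:
          - [512, 512]
        targets:
          # the concept being turned into a slider (illustrative)
          - target_class: "person smiling"
            weight: 1.0
        anchors:
          # prompts held steady while the slider trains
          - prompt: "a woman"
            neg_prompt: "animal"
            multiplier: 8.0
```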
@@ -2,6 +2,7 @@ import time
from collections import OrderedDict
import os

+from leco.train_util import predict_noise
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -59,11 +60,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.logging_config = LogingConfig(**self.get_conf('logging', {}))
        self.optimizer = None
        self.lr_scheduler = None
-        self.sd = None
+        self.sd: 'StableDiffusion' = None

        # added later
        self.network = None
-        self.scheduler = None

    def sample(self, step=None):
        sample_folder = os.path.join(self.save_root, 'samples')
@@ -118,7 +118,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                'multiplier': self.network.multiplier,
            })

-        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
+        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
            raw_prompt = self.sample_config.prompts[i]

            neg = self.sample_config.neg
@@ -267,6 +267,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # return loss
        return 0.0

+    # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
+    def diffuse_some_steps(
+            self,
+            latents: torch.FloatTensor,
+            text_embeddings: torch.FloatTensor,
+            total_timesteps: int = 1000,
+            start_timesteps=0,
+            **kwargs,
+    ):
+        for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
+            noise_pred = train_util.predict_noise(
+                self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
+            )
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
+
+        return latents
+
    def run(self):
        super().run()
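A hedged sketch of how `diffuse_some_steps` is meant to be used (the helper name and variables below are my assumptions, not from the commit): denoise a shared noise latent for only part of the schedule, so the later slider and anchor noise predictions all branch from the same partially denoised state.

```python
import torch

def partially_denoise(trainer, latents, text_embeddings):
    # denoise only a prefix of the schedule; the returned latents become the
    # common starting point for every predict_noise call in the train step
    stop_at = torch.randint(1, 50, (1,)).item()  # illustrative step count
    with torch.no_grad():
        return trainer.diffuse_some_steps(
            latents,            # pure noise latents
            text_embeddings,    # concatenated prompt embeddings
            start_timesteps=0,
            total_timesteps=stop_at,
        )
```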
@@ -368,7 +389,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
            # todo handle dataloader here maybe, not sure

            ### HOOK ###
-            loss = self.hook_train_loop()
+            loss_dict = self.hook_train_loop()
+
+            if self.train_config.optimizer.startswith('dadaptation'):
+                learning_rate = (
+                    optimizer.param_groups[0]["d"] *
+                    optimizer.param_groups[0]["lr"]
+                )
+            else:
+                learning_rate = optimizer.param_groups[0]['lr']
+
+            prog_bar_string = f"lr: {learning_rate:.1e}"
+            for key, value in loss_dict.items():
+                prog_bar_string += f" {key}: {value:.3e}"
+
+            self.progress_bar.set_postfix_str(prog_bar_string)

            # don't do on first step
            if self.step_num != self.start_step:
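This hunk is the contract change: `hook_train_loop` used to return a single float and now returns an ordered mapping of named scalars, each echoed to the progress bar (and, below, to tensorboard). A small runnable sketch of the assumed shape, with illustrative values:

```python
from collections import OrderedDict

# what a slider train step with anchors enabled might return
loss_dict = OrderedDict({'loss': 1.23e-02})
loss_dict['sl_l'] = 1.00e-02  # slider portion of the loss
loss_dict['an_l'] = 2.30e-03  # anchor regularizer portion

prog_bar_string = "lr: 1.0e-04"
for key, value in loss_dict.items():
    prog_bar_string += f" {key}: {value:.3e}"
# -> "lr: 1.0e-04 loss: 1.230e-02 sl_l: 1.000e-02 an_l: 2.300e-03"
```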
@@ -386,15 +421,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
            if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
                # log to tensorboard
                if self.writer is not None:
-                    # get avg loss
-                    self.writer.add_scalar(f"loss", loss, self.step_num)
-                    if self.train_config.optimizer.startswith('dadaptation'):
-                        learning_rate = (
-                            optimizer.param_groups[0]["d"] *
-                            optimizer.param_groups[0]["lr"]
-                        )
-                    else:
-                        learning_rate = optimizer.param_groups[0]['lr']
+                    for key, value in loss_dict.items():
+                        self.writer.add_scalar(f"{key}", value, self.step_num)
                    self.writer.add_scalar(f"lr", learning_rate, self.step_num)
            self.progress_bar.refresh()
@@ -29,6 +29,7 @@ def flush():
    gc.collect()


+
class EncodedPromptPair:
    def __init__(
            self,
@@ -53,6 +54,18 @@ class EncodedPromptPair:
        self.weight = weight


+class EncodedAnchor:
+    def __init__(
+            self,
+            prompt,
+            neg_prompt,
+            multiplier=1.0
+    ):
+        self.prompt = prompt
+        self.neg_prompt = neg_prompt
+        self.multiplier = multiplier
+
+
class TrainSliderProcess(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
@@ -61,9 +74,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.slider_config = SliderConfig(**self.get_conf('slider', {}))

        self.prompt_cache = PromptEmbedsCache()
        self.prompt_pairs: list[EncodedPromptPair] = []
+        self.anchor_pairs: list[EncodedAnchor] = []

    def before_model_load(self):
        pass
@@ -146,16 +159,39 @@ class TrainSliderProcess(BaseSDTrainProcess):
            ),
        ]

+        # setup anchors
+        anchor_pairs = []
+        for anchor in self.slider_config.anchors:
+            # build the cache
+            for prompt in [
+                anchor.prompt,
+                anchor.neg_prompt  # empty neutral
+            ]:
+                if cache[prompt] is None:
+                    cache[prompt] = train_util.encode_prompts(
+                        self.sd.tokenizer, self.sd.text_encoder, [prompt]
+                    )
+
+            anchor_pairs += [
+                EncodedAnchor(
+                    prompt=cache[anchor.prompt],
+                    neg_prompt=cache[anchor.neg_prompt],
+                    multiplier=anchor.multiplier
+                )
+            ]

        # move to cpu to save vram
        # We don't need the text encoder anymore, but keep it on cpu for sampling
        self.sd.text_encoder.to("cpu")
        self.prompt_cache = cache
        self.prompt_pairs = prompt_pairs
+        self.anchor_pairs = anchor_pairs
        flush()
        # end hook_before_train_loop

def hook_train_loop(self):
        dtype = get_torch_dtype(self.train_config.dtype)

        # get a random pair
        prompt_pair: EncodedPromptPair = self.prompt_pairs[
            torch.randint(0, len(self.prompt_pairs), (1,)).item()
        ]
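The anchor setup above leans on `PromptEmbedsCache` returning `None` for prompts it has not seen, so each unique prompt string is encoded exactly once. The cache class itself is not shown in this diff; the sketch below is a minimal assumed implementation of that contract, not the library's actual code.

```python
class PromptEmbedsCache:
    """Assumed contract: dict-like store that yields None on a miss."""

    def __init__(self):
        self._cache = {}

    def __setitem__(self, prompt: str, embeds):
        self._cache[prompt] = embeds

    def __getitem__(self, prompt: str):
        # returning None (instead of raising KeyError) is what lets the
        # trainer write: if cache[prompt] is None: encode and store
        return self._cache.get(prompt, None)
```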
@@ -202,10 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
            with self.network:
                assert self.network.is_active
                # A little denoised one is returned
-                denoised_latents = train_util.diffusion(
-                    unet,
-                    noise_scheduler,
+                denoised_latents = self.diffuse_some_steps(
                    latents,  # pass simple noise latents
                    train_util.concat_embeddings(
                        positive,  # unconditional
@@ -261,7 +294,46 @@ class TrainSliderProcess(BaseSDTrainProcess):
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)

+            anchor_loss = None
+            if len(self.anchor_pairs) > 0:
+                # get a random anchor pair
+                anchor: EncodedAnchor = self.anchor_pairs[
+                    torch.randint(0, len(self.anchor_pairs), (1,)).item()
+                ]
+                with torch.no_grad():
+                    anchor_target_noise = train_util.predict_noise(
+                        unet,
+                        noise_scheduler,
+                        current_timestep,
+                        denoised_latents,
+                        train_util.concat_embeddings(
+                            anchor.prompt,
+                            anchor.neg_prompt,
+                            self.train_config.batch_size,
+                        ),
+                        guidance_scale=1,
+                    ).to("cpu", dtype=torch.float32)
+                with self.network:
+                    # anchor at whatever weight the prompt pair is using
+                    pos_neg_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
+                    self.network.multiplier = anchor.multiplier * pos_neg_mult
+                    anchor_pred_noise = train_util.predict_noise(
+                        unet,
+                        noise_scheduler,
+                        current_timestep,
+                        denoised_latents,
+                        train_util.concat_embeddings(
+                            anchor.prompt,
+                            anchor.neg_prompt,
+                            self.train_config.batch_size,
+                        ),
+                        guidance_scale=1,
+                    ).to("cpu", dtype=torch.float32)
+
+                self.network.multiplier = prompt_pair.multiplier

            with self.network:
+                self.network.multiplier = prompt_pair.multiplier
                target_latents = train_util.predict_noise(
                    unet,
                    noise_scheduler,
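Reading the two `predict_noise` calls together: the first records what the frozen base model predicts for the anchor prompt, the second asks the same question with the LoRA enabled at the anchor multiplier, sign-matched to the current slider pair. A hedged summary of the resulting objective, assuming `loss_function` is the usual MSE criterion and using symbols of my own choosing:

```latex
\mathcal{L}_{\text{anchor}}
  = \bigl\lVert \epsilon_{\theta + \Delta\theta \cdot m_a}(x_t, c_a)
              - \epsilon_{\theta}(x_t, c_a) \bigr\rVert_2^2 ,
\qquad
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{slider}} + \mathcal{L}_{\text{anchor}}
```

Here $\epsilon_{\theta}$ is the base UNet's noise prediction, $\Delta\theta \cdot m_a$ the LoRA weights scaled by the anchor multiplier, $x_t$ the partially denoised latents, and $c_a$ the anchor prompt embedding. The LoRA is free to move the slider concept but is penalized for drifting on the anchor prompts.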
@@ -281,7 +353,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False
            unconditional_latents.requires_grad = False

+            if len(self.anchor_pairs) > 0:
+                anchor_target_noise.requires_grad = False
+                anchor_loss = loss_function(
+                    anchor_target_noise,
+                    anchor_pred_noise,
+                )
            erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
            guidance_scale = 1.0
@@ -299,16 +376,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
                offset_neutral,
            ) * weight

-            loss_float = loss.item()
-            if self.train_config.optimizer.startswith('dadaptation'):
-                learning_rate = (
-                    optimizer.param_groups[0]["d"] *
-                    optimizer.param_groups[0]["lr"]
-                )
-            else:
-                learning_rate = optimizer.param_groups[0]['lr']
-
-            self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}")
+            loss_slide = loss.item()
+
+            if anchor_loss is not None:
+                loss += anchor_loss
+
+            loss_float = loss.item()

            loss = loss.to(self.device_torch)

            loss.backward()
            optimizer.step()
@@ -326,5 +401,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
        # reset network
        self.network.multiplier = 1.0

-        return loss_float
+        loss_dict = OrderedDict(
+            {'loss': loss_float},
+        )
+        if anchor_loss is not None:
+            loss_dict['sl_l'] = loss_slide
+            loss_dict['an_l'] = anchor_loss.item()
+
+        return loss_dict
        # end hook_train_loop
@@ -71,9 +71,19 @@ class SliderTargetConfig:
        self.weight: float = kwargs.get('weight', 1.0)


+class SliderConfigAnchors:
+    def __init__(self, **kwargs):
+        self.prompt = kwargs.get('prompt', '')
+        self.neg_prompt = kwargs.get('neg_prompt', '')
+        self.multiplier = kwargs.get('multiplier', 1.0)
+
+
class SliderConfig:
    def __init__(self, **kwargs):
        targets = kwargs.get('targets', [])
        targets = [SliderTargetConfig(**target) for target in targets]
        self.targets: List[SliderTargetConfig] = targets
+        anchors = kwargs.get('anchors', [])
+        anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
+        self.anchors: List[SliderConfigAnchors] = anchors
        self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
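To close the loop with the README hunk above, here is a hedged sketch of these config classes consuming the YAML anchors block. The import path for `SliderConfig` is not shown in this diff, so it is assumed to be in scope; the printed values follow directly from the defaults above.

```python
import yaml

raw_slider_conf = yaml.safe_load("""
anchors:
  - prompt: "a woman"
    neg_prompt: "animal"
    multiplier: 8.0
""")

# SliderConfig is assumed importable from the toolkit's config module
slider_config = SliderConfig(**raw_slider_conf)
print(slider_config.anchors[0].prompt)      # a woman
print(slider_config.anchors[0].multiplier)  # 8.0
print(slider_config.resolutions)            # [[512, 512]] (the default)
```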