Added anchors to regulate the LoRA

Jaret Burkett
2023-07-24 14:59:16 -06:00
parent 390192c6a1
commit 61dd818608
5 changed files with 180 additions and 31 deletions


@@ -29,6 +29,7 @@ def flush():
    gc.collect()


class EncodedPromptPair:
    def __init__(
            self,
@@ -53,6 +54,18 @@ class EncodedPromptPair:
        self.weight = weight

class EncodedAnchor:
    def __init__(
            self,
            prompt,
            neg_prompt,
            multiplier=1.0
    ):
        self.prompt = prompt
        self.neg_prompt = neg_prompt
        self.multiplier = multiplier

class TrainSliderProcess(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
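For illustration only, a minimal sketch of how the new EncodedAnchor class above might be used once prompt embeddings exist; the tensors here are stand-ins for whatever train_util.encode_prompts actually returns (this snippet is not part of the commit):

import torch

# stand-in embeddings; the real values come from the text encoder
prompt_embeds = torch.randn(1, 77, 768)
neg_prompt_embeds = torch.randn(1, 77, 768)  # e.g. the empty prompt

anchor = EncodedAnchor(
    prompt=prompt_embeds,
    neg_prompt=neg_prompt_embeds,
    multiplier=0.5,
)
assert anchor.multiplier == 0.5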
@@ -61,9 +74,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.slider_config = SliderConfig(**self.get_conf('slider', {}))
        self.prompt_cache = PromptEmbedsCache()
        self.prompt_pairs: list[EncodedPromptPair] = []
        self.anchor_pairs: list[EncodedAnchor] = []

    def before_model_load(self):
        pass
@@ -146,16 +159,39 @@ class TrainSliderProcess(BaseSDTrainProcess):
                ),
            ]

        # setup anchors
        anchor_pairs = []
        for anchor in self.slider_config.anchors:
            # build the cache
            for prompt in [
                anchor.prompt,
                anchor.neg_prompt  # empty neutral
            ]:
                if cache[prompt] is None:  # a cache miss returns None (see the stand-in sketch below)
                    cache[prompt] = train_util.encode_prompts(
                        self.sd.tokenizer, self.sd.text_encoder, [prompt]
                    )

            anchor_pairs += [
                EncodedAnchor(
                    prompt=cache[anchor.prompt],
                    neg_prompt=cache[anchor.neg_prompt],
                    multiplier=anchor.multiplier
                )
            ]

        # the text encoder is no longer needed for training;
        # move it to cpu to save vram but keep it around for sampling
        self.sd.text_encoder.to("cpu")
        self.prompt_cache = cache
        self.prompt_pairs = prompt_pairs
        self.anchor_pairs = anchor_pairs
        flush()
    # end hook_before_train_loop

    def hook_train_loop(self):
        dtype = get_torch_dtype(self.train_config.dtype)

        # get a random prompt pair for this step
        prompt_pair: EncodedPromptPair = self.prompt_pairs[
            torch.randint(0, len(self.prompt_pairs), (1,)).item()
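The `cache[prompt] is None` check in the hunk above assumes the prompt cache returns None on a miss instead of raising KeyError. A minimal stand-in with that behavior (the real PromptEmbedsCache is defined elsewhere in the repo, so treat this interface as an assumption):

class SimplePromptCache:
    """Dict-backed cache keyed by prompt text; misses return None."""

    def __init__(self):
        self._store = {}

    def __getitem__(self, prompt: str):
        return self._store.get(prompt, None)

    def __setitem__(self, prompt: str, embeds):
        self._store[prompt] = embeds


cache = SimplePromptCache()
assert cache["a photo of a person"] is None   # miss, no exception
cache["a photo of a person"] = object()       # would be prompt embeddings
assert cache["a photo of a person"] is not None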
@@ -202,10 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
        with self.network:
            assert self.network.is_active

            # a slightly denoised version of the latents is returned
            # (was: denoised_latents = train_util.diffusion(unet, noise_scheduler, ...))
            denoised_latents = self.diffuse_some_steps(
                latents,  # pass simple noise latents
                train_util.concat_embeddings(
                    positive,  # unconditional
@@ -261,7 +294,46 @@ class TrainSliderProcess(BaseSDTrainProcess):
                    guidance_scale=1,
                ).to("cpu", dtype=torch.float32)

        anchor_loss = None
        if len(self.anchor_pairs) > 0:
            # get a random anchor pair
            anchor: EncodedAnchor = self.anchor_pairs[
                torch.randint(0, len(self.anchor_pairs), (1,)).item()
            ]
            with torch.no_grad():
                anchor_target_noise = train_util.predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    train_util.concat_embeddings(
                        anchor.prompt,
                        anchor.neg_prompt,
                        self.train_config.batch_size,
                    ),
                    guidance_scale=1,
                ).to("cpu", dtype=torch.float32)

            with self.network:
                # match the sign of the multiplier the prompt pair is using
                pos_neg_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
                self.network.multiplier = anchor.multiplier * pos_neg_mult
                anchor_pred_noise = train_util.predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    train_util.concat_embeddings(
                        anchor.prompt,
                        anchor.neg_prompt,
                        self.train_config.batch_size,
                    ),
                    guidance_scale=1,
                ).to("cpu", dtype=torch.float32)
                self.network.multiplier = prompt_pair.multiplier

        with self.network:  # LoRA active inside this block (see the stand-in sketch below)
            self.network.multiplier = prompt_pair.multiplier
            target_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
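The `with self.network:` blocks toggle the LoRA on around each prediction, and setting `self.network.multiplier` scales (or, when negative, inverts) its effect. A stand-in sketch of that context-manager pattern, inferred from how the network is used in this file rather than from its actual implementation:

class ToyLoRANetwork:
    """Context manager that activates LoRA weights at a settable multiplier."""

    def __init__(self):
        self.multiplier = 1.0
        self.is_active = False

    def __enter__(self):
        self.is_active = True   # predictions now include the LoRA deltas
        return self

    def __exit__(self, exc_type, exc, tb):
        self.is_active = False  # back to the base model


network = ToyLoRANetwork()
with network:
    network.multiplier = -0.5   # negative values flip the slider direction
    assert network.is_active
assert not network.is_active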
@@ -281,7 +353,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False
        unconditional_latents.requires_grad = False

        if len(self.anchor_pairs) > 0:
            anchor_target_noise.requires_grad = False
            # regularize: penalize drift on the anchor (a sketch of this objective follows the hunk)
            anchor_loss = loss_function(
                anchor_target_noise,
                anchor_pred_noise,
            )

        erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
        guidance_scale = 1.0
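Tying the two predictions together: anchor_target_noise comes from the base model (LoRA inactive, under no_grad) and anchor_pred_noise from the LoRA at the anchor multiplier, so minimizing their distance keeps anchored concepts stable while the slider concepts still move. A sketch assuming loss_function is an MSE-style criterion, which this diff does not itself show:

import torch
import torch.nn.functional as F

def anchor_mse(anchor_target_noise: torch.Tensor,
               anchor_pred_noise: torch.Tensor) -> torch.Tensor:
    # the target is frozen; gradient flows only through the prediction
    return F.mse_loss(anchor_pred_noise.float(),
                      anchor_target_noise.float().detach())

target = torch.randn(1, 4, 64, 64)               # stand-in noise tensors
pred = target + 0.01 * torch.randn_like(target)
print(anchor_mse(target, pred))                  # small if the LoRA barely moves the anchor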
@@ -299,16 +376,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
            offset_neutral,
        ) * weight
        # (was: loss_float = loss.item(); loss_float is now set after the anchor loss is added below)

        if self.train_config.optimizer.startswith('dadaptation'):
            # D-Adaptation learns a scale "d"; the effective lr is d * lr (helper sketch below)
            learning_rate = (
                optimizer.param_groups[0]["d"] *
                optimizer.param_groups[0]["lr"]
            )
        else:
            learning_rate = optimizer.param_groups[0]['lr']

        loss_slide = loss.item()
        self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}")

        if anchor_loss is not None:
            loss += anchor_loss
        loss_float = loss.item()

        loss = loss.to(self.device_torch)
        loss.backward()
        optimizer.step()
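D-Adaptation optimizers keep the configured lr fixed (typically 1.0) and learn a scale d, so the step size actually applied is d * lr; that is what the logging branch above reads out. A small helper capturing the same logic (the "d" param-group key matches the dadaptation package's layout, but treat it as an assumption here):

def effective_lr(optimizer, optimizer_name: str) -> float:
    group = optimizer.param_groups[0]
    if optimizer_name.startswith('dadaptation'):
        # adaptive scale learned by D-Adaptation times the static lr
        return group["d"] * group["lr"]
    return group["lr"]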
@@ -326,5 +401,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
        # reset network multiplier
        self.network.multiplier = 1.0

        # (was: return loss_float)
        loss_dict = OrderedDict(
            {'loss': loss_float},
        )
        if anchor_loss is not None:
            loss_dict['sl_l'] = loss_slide          # slider component
            loss_dict['an_l'] = anchor_loss.item()  # anchor component

        return loss_dict
    # end hook_train_loop
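With this change the hook returns an OrderedDict instead of a bare float, so callers can log the slider and anchor components separately. A hypothetical caller-side sketch (process and progress_bar are illustrative names, not APIs confirmed by this diff):

loss_dict = process.hook_train_loop()
# e.g. OrderedDict([('loss', 0.012), ('sl_l', 0.009), ('an_l', 0.003)])
postfix = " ".join(f"{key}: {value:.3e}" for key, value in loss_dict.items())
process.progress_bar.set_postfix_str(postfix)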