Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00

Added anchors to regulate the LoRA

@@ -29,6 +29,7 @@ def flush():
     gc.collect()
 
 
 class EncodedPromptPair:
     def __init__(
             self,
@@ -53,6 +54,18 @@ class EncodedPromptPair:
         self.weight = weight
 
 
+class EncodedAnchor:
+    def __init__(
+            self,
+            prompt,
+            neg_prompt,
+            multiplier=1.0
+    ):
+        self.prompt = prompt
+        self.neg_prompt = neg_prompt
+        self.multiplier = multiplier
+
+
 class TrainSliderProcess(BaseSDTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
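
EncodedAnchor is a plain container pairing pre-encoded prompt embeddings with a strength multiplier. A minimal standalone sketch of how one might be populated; fake_encode below is a hypothetical stand-in for train_util.encode_prompts, and the embedding shape is illustrative:

    import torch

    class EncodedAnchor:
        def __init__(self, prompt, neg_prompt, multiplier=1.0):
            self.prompt = prompt
            self.neg_prompt = neg_prompt
            self.multiplier = multiplier

    def fake_encode(text: str) -> torch.Tensor:
        # hypothetical stand-in: a real text encoder returns
        # (batch, seq_len, dim) embeddings for the prompt
        return torch.zeros(1, 77, 768)

    anchor = EncodedAnchor(
        prompt=fake_encode("a photo of a person"),
        neg_prompt=fake_encode(""),  # empty neutral
        multiplier=1.0,
    )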
@@ -61,9 +74,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
         self.device = self.get_conf('device', self.job.device)
         self.device_torch = torch.device(self.device)
         self.slider_config = SliderConfig(**self.get_conf('slider', {}))
 
         self.prompt_cache = PromptEmbedsCache()
         self.prompt_pairs: list[EncodedPromptPair] = []
+        self.anchor_pairs: list[EncodedAnchor] = []
 
     def before_model_load(self):
         pass
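
The anchors presumably arrive through the slider section of the job config, since SliderConfig is built from self.get_conf('slider', {}). A hypothetical sketch of that section, with key names inferred only from the attribute accesses in this diff (anchor.prompt, anchor.neg_prompt, anchor.multiplier), not from a documented schema:

    # hypothetical config fragment; the real SliderConfig schema may differ
    slider_conf = {
        'anchors': [
            {
                'prompt': 'a photo of a person',
                'neg_prompt': '',  # empty neutral
                'multiplier': 1.0,
            },
        ],
    }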
@@ -146,16 +159,39 @@ class TrainSliderProcess(BaseSDTrainProcess):
                 ),
             ]
+
+        # setup anchors
+        anchor_pairs = []
+        for anchor in self.slider_config.anchors:
+            # build the cache
+            for prompt in [
+                anchor.prompt,
+                anchor.neg_prompt  # empty neutral
+            ]:
+                if cache[prompt] == None:
+                    cache[prompt] = train_util.encode_prompts(
+                        self.sd.tokenizer, self.sd.text_encoder, [prompt]
+                    )
+
+            anchor_pairs += [
+                EncodedAnchor(
+                    prompt=cache[anchor.prompt],
+                    neg_prompt=cache[anchor.neg_prompt],
+                    multiplier=anchor.multiplier
+                )
+            ]
+
         # move to cpu to save vram
         # We don't need text encoder anymore, but keep it on cpu for sampling
         self.sd.text_encoder.to("cpu")
         self.prompt_cache = cache
         self.prompt_pairs = prompt_pairs
+        self.anchor_pairs = anchor_pairs
         flush()
         # end hook_before_train_loop
 
     def hook_train_loop(self):
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # get a random pair
         prompt_pair: EncodedPromptPair = self.prompt_pairs[
             torch.randint(0, len(self.prompt_pairs), (1,)).item()
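
The anchor setup above reuses the prompt-embedding cache so each unique string is encoded once, no matter how many pairs or anchors reference it. The pattern in isolation, with a hypothetical encode standing in for train_util.encode_prompts:

    import torch

    cache: dict[str, torch.Tensor] = {}

    def encode(prompt: str) -> torch.Tensor:
        return torch.zeros(1, 77, 768)  # hypothetical stand-in embedding

    for prompt in ["a photo of a person", ""]:
        if cache.get(prompt) is None:  # encode each unique prompt once
            cache[prompt] = encode(prompt)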
@@ -202,10 +238,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
 
             with self.network:
                 assert self.network.is_active
-                # A little denoised one is returned
-                denoised_latents = train_util.diffusion(
-                    unet,
-                    noise_scheduler,
+                denoised_latents = self.diffuse_some_steps(
                     latents,  # pass simple noise latents
                     train_util.concat_embeddings(
                         positive,  # unconditional
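
diffuse_some_steps is not shown in this diff; from its call site it appears to run the sampler for only part of the schedule, so training starts from a partially denoised latent rather than pure noise. A rough, purely illustrative sketch of that idea, not the repo's implementation:

    import torch

    def diffuse_some_steps_sketch(latents, step_fn, timesteps, stop_at: int):
        # step_fn(latents, t) -> slightly less noisy latents (hypothetical)
        for i, t in enumerate(timesteps):
            if i >= stop_at:
                break
            latents = step_fn(latents, t)
        return latents  # partially denoised, not a finished image

    # e.g. halt after 10 of 50 scheduler steps
    out = diffuse_some_steps_sketch(
        torch.randn(1, 4, 64, 64),
        step_fn=lambda lat, t: lat * 0.99,  # dummy denoise step
        timesteps=range(50),
        stop_at=10,
    )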
@@ -261,7 +294,46 @@ class TrainSliderProcess(BaseSDTrainProcess):
                 guidance_scale=1,
             ).to("cpu", dtype=torch.float32)
+
+        anchor_loss = None
+        if len(self.anchor_pairs) > 0:
+            # get a random anchor pair
+            anchor: EncodedAnchor = self.anchor_pairs[
+                torch.randint(0, len(self.anchor_pairs), (1,)).item()
+            ]
+            with torch.no_grad():
+                anchor_target_noise = train_util.predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    train_util.concat_embeddings(
+                        anchor.prompt,
+                        anchor.neg_prompt,
+                        self.train_config.batch_size,
+                    ),
+                    guidance_scale=1,
+                ).to("cpu", dtype=torch.float32)
+            with self.network:
+                # anchor whatever weight prompt pair is using
+                pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
+                self.network.multiplier = anchor.multiplier * pos_nem_mult
+                anchor_pred_noise = train_util.predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    train_util.concat_embeddings(
+                        anchor.prompt,
+                        anchor.neg_prompt,
+                        self.train_config.batch_size,
+                    ),
+                    guidance_scale=1,
+                ).to("cpu", dtype=torch.float32)
+
+                self.network.multiplier = prompt_pair.multiplier
+
         with self.network:
             self.network.multiplier = prompt_pair.multiplier
             target_latents = train_util.predict_noise(
                 unet,
                 noise_scheduler,
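
The pos_nem_mult line matches the anchor's network multiplier to the polarity of the sampled prompt pair, so the anchor regularizes whichever direction the slider is currently training. The sign logic in isolation:

    def signed_anchor_multiplier(pair_mult: float, anchor_mult: float) -> float:
        # follow the polarity of the current prompt pair
        pos_nem_mult = 1.0 if pair_mult > 0 else -1.0
        return anchor_mult * pos_nem_mult

    assert signed_anchor_multiplier(1.0, 0.5) == 0.5
    assert signed_anchor_multiplier(-1.0, 0.5) == -0.5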
@@ -281,7 +353,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
         positive_latents.requires_grad = False
         neutral_latents.requires_grad = False
         unconditional_latents.requires_grad = False
-
+        if len(self.anchor_pairs) > 0:
+            anchor_target_noise.requires_grad = False
+            anchor_loss = loss_function(
+                anchor_target_noise,
+                anchor_pred_noise,
+            )
         erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
 
         guidance_scale = 1.0
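
The anchor term is a plain regression between the frozen target (computed under torch.no_grad()) and the network-active prediction, so gradients only flow through the LoRA path. Assuming loss_function is an MSE-style criterion (not confirmed by this diff), it reduces to:

    import torch
    import torch.nn.functional as F

    # illustrative shapes; SD latents are (batch, 4, h/8, w/8)
    anchor_target_noise = torch.randn(1, 4, 64, 64)  # frozen, no grad
    anchor_pred_noise = torch.randn(1, 4, 64, 64, requires_grad=True)

    anchor_loss = F.mse_loss(anchor_pred_noise, anchor_target_noise)
    anchor_loss.backward()  # gradients reach only the prediction path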
@@ -299,16 +376,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
             offset_neutral,
         ) * weight
 
-        loss_float = loss.item()
         if self.train_config.optimizer.startswith('dadaptation'):
             learning_rate = (
                 optimizer.param_groups[0]["d"] *
                 optimizer.param_groups[0]["lr"]
             )
         else:
             learning_rate = optimizer.param_groups[0]['lr']
+        loss_slide = loss.item()
 
         self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}")
+        if anchor_loss is not None:
+            loss += anchor_loss
+
+        loss_float = loss.item()
 
         loss = loss.to(self.device_torch)
 
         loss.backward()
         optimizer.step()
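
For D-Adaptation optimizers the base lr is only a scale on the adapted step size d, so the displayed rate is their product; otherwise lr is used directly. The readout in isolation, using the param-group keys as accessed above:

    import torch

    def effective_lr(optimizer, optimizer_name: str) -> float:
        group = optimizer.param_groups[0]
        if optimizer_name.startswith('dadaptation'):
            return group["d"] * group["lr"]  # adapted rate
        return group['lr']

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.SGD(params, lr=1e-3)
    assert effective_lr(opt, 'adamw') == 1e-3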
@@ -326,5 +401,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
         # reset network
         self.network.multiplier = 1.0
 
-        return loss_float
+        loss_dict = OrderedDict(
+            {'loss': loss_float},
+        )
+        if anchor_loss is not None:
+            loss_dict['sl_l'] = loss_slide
+            loss_dict['an_l'] = anchor_loss.item()
+
+        return loss_dict
         # end hook_train_loop
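
Returning an OrderedDict instead of a bare float lets the caller log each component separately; 'sl_l' (slider loss) and 'an_l' (anchor loss) only appear when anchors are configured. A hypothetical consumer:

    from collections import OrderedDict

    loss_dict = OrderedDict({'loss': 0.123, 'sl_l': 0.100, 'an_l': 0.023})
    print(" ".join(f"{k}: {v:.3e}" for k, v in loss_dict.items()))
    # -> loss: 1.230e-01 sl_l: 1.000e-01 an_l: 2.300e-02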