Added anchors to regulate the LoRA

commit 61dd818608 (parent 390192c6a1)
Author: Jaret Burkett
Date: 2023-07-24 14:59:16 -06:00
5 changed files with 180 additions and 31 deletions


@@ -2,6 +2,7 @@ import time
 from collections import OrderedDict
 import os
+from leco.train_util import predict_noise
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -59,11 +60,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
-        self.sd = None
+        self.sd: 'StableDiffusion' = None
         # added later
         self.network = None
-        self.scheduler = None

     def sample(self, step=None):
         sample_folder = os.path.join(self.save_root, 'samples')
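The quoted type in self.sd: 'StableDiffusion' = None is a forward reference: the annotation is stored as a string, so editors and type checkers learn the attribute's type without the class having to be importable at runtime. A minimal sketch of the usual pattern, assuming a hypothetical module path for StableDiffusion (this commit does not show where the class actually lives):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # only evaluated by static type checkers, never at runtime,
    # which avoids circular or heavyweight imports
    from toolkit.stable_diffusion_model import StableDiffusion  # hypothetical path

class Example:
    def __init__(self):
        # resolved lazily; editors still get completion on self.sd
        self.sd: Optional['StableDiffusion'] = None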
@@ -118,7 +118,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 'multiplier': self.network.multiplier,
             })

-            for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
+            for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
                 raw_prompt = self.sample_config.prompts[i]
                 neg = self.sample_config.neg
@@ -267,6 +267,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # return loss
         return 0.0

+    # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
+    def diffuse_some_steps(
+            self,
+            latents: torch.FloatTensor,
+            text_embeddings: torch.FloatTensor,
+            total_timesteps: int = 1000,
+            start_timesteps=0,
+            **kwargs,
+    ):
+        for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
+            noise_pred = train_util.predict_noise(
+                self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
+            )
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
+        # return latents_steps
+        return latents
+
     def run(self):
         super().run()
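diffuse_some_steps runs the plain denoising loop over a slice of the scheduler's timesteps, so a caller can stop partway through and keep the intermediate latents, which is what anchor-based regularization needs. A hedged sketch of a possible call site follows; the helper name, latent shape, and prompt handling are assumptions, since this commit does not show how the method is invoked:

import torch

def make_anchor_latents(process, text_embeddings, height=512, width=512, sample_steps=50, stop_at=30):
    # assumed setup: a fresh timestep schedule and unit-Gaussian latents
    process.sd.noise_scheduler.set_timesteps(sample_steps)
    latents = torch.randn(
        (1, 4, height // 8, width // 8),
        device=process.sd.unet.device,
        dtype=process.sd.unet.dtype,
    )
    with torch.no_grad():
        # run only the first `stop_at` of the `sample_steps` scheduler steps,
        # leaving partially denoised latents to use as an anchor target
        latents = process.diffuse_some_steps(
            latents,
            text_embeddings,
            total_timesteps=stop_at,
            start_timesteps=0,
        )
    return latents

Note the sketch does not scale the initial latents by the scheduler's init_noise_sigma; whether that matters depends on which scheduler the real code configures.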
@@ -368,7 +389,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # todo handle dataloader here maybe, not sure

             ### HOOK ###
-            loss = self.hook_train_loop()
+            loss_dict = self.hook_train_loop()
+
+            if self.train_config.optimizer.startswith('dadaptation'):
+                learning_rate = (
+                    optimizer.param_groups[0]["d"] *
+                    optimizer.param_groups[0]["lr"]
+                )
+            else:
+                learning_rate = optimizer.param_groups[0]['lr']
+
+            prog_bar_string = f"lr: {learning_rate:.1e}"
+            for key, value in loss_dict.items():
+                prog_bar_string += f" {key}: {value:.3e}"
+
+            self.progress_bar.set_postfix_str(prog_bar_string)

             # don't do on first step
             if self.step_num != self.start_step:
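hook_train_loop now returns a mapping of named scalar losses instead of a single float, and the loop above folds each entry into the progress-bar postfix. A small self-contained example of the expected shape and the resulting string; the key names here are hypothetical:

from collections import OrderedDict

# any float-valued mapping works; an 'anchor_loss' entry is only an example
loss_dict = OrderedDict([('loss', 0.0123), ('anchor_loss', 0.00045)])
learning_rate = 1e-4

prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
    prog_bar_string += f" {key}: {value:.3e}"

print(prog_bar_string)  # lr: 1.0e-04 loss: 1.230e-02 anchor_loss: 4.500e-04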
@@ -386,15 +421,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
                     # log to tensorboard
                     if self.writer is not None:
-                        # get avg loss
-                        self.writer.add_scalar(f"loss", loss, self.step_num)
-                        if self.train_config.optimizer.startswith('dadaptation'):
-                            learning_rate = (
-                                optimizer.param_groups[0]["d"] *
-                                optimizer.param_groups[0]["lr"]
-                            )
-                        else:
-                            learning_rate = optimizer.param_groups[0]['lr']
+                        for key, value in loss_dict.items():
+                            self.writer.add_scalar(f"{key}", value, self.step_num)
                         self.writer.add_scalar(f"lr", learning_rate, self.step_num)

             self.progress_bar.refresh()
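With the logging change, every key in loss_dict becomes its own TensorBoard scalar series alongside a single lr series, rather than one combined loss value. A stand-alone illustration using torch's SummaryWriter; the log directory and values below are placeholders:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='logs/example')  # placeholder directory
loss_dict = {'loss': 0.0123, 'anchor_loss': 0.00045}  # placeholder values
step_num = 100
learning_rate = 1e-4

# one scalar tag per loss term, plus the learning rate
for key, value in loss_dict.items():
    writer.add_scalar(key, value, step_num)
writer.add_scalar('lr', learning_rate, step_num)
writer.flush()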