Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00

Added anchors to regulate the lora

@@ -2,6 +2,7 @@ import time
 from collections import OrderedDict
 import os
 
+from leco.train_util import predict_noise
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -59,11 +60,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
-        self.sd = None
+        self.sd: 'StableDiffusion' = None
 
         # added later
         self.network = None
-        self.scheduler = None
 
     def sample(self, step=None):
         sample_folder = os.path.join(self.save_root, 'samples')
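The self.sd attribute is now annotated with the quoted type 'StableDiffusion', a string forward reference: type checkers resolve the name, but nothing has to be imported at runtime (useful when the import would be circular or heavy). A minimal sketch of the pattern, with a hypothetical import path used purely for illustration:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # only evaluated by type checkers, never at runtime; module path is illustrative
    from toolkit.stable_diffusion_model import StableDiffusion

class Example:
    def __init__(self):
        # same idea as self.sd: 'StableDiffusion' = None in the diff above
        self.sd: Optional['StableDiffusion'] = None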
@@ -118,7 +118,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             'multiplier': self.network.multiplier,
         })
 
-        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
+        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
             raw_prompt = self.sample_config.prompts[i]
 
             neg = self.sample_config.neg
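The sampling loop's tqdm bar now passes leave=False, so the per-prompt bar is cleared when sampling finishes instead of leaving one stale line behind the outer training bar every time samples are generated. A tiny self-contained illustration of the behavior:

import time
from tqdm import tqdm

for step in tqdm(range(3), desc="training steps"):
    # inner bar is erased when it completes because of leave=False
    for i in tqdm(range(5), desc="generating samples", leave=False):
        time.sleep(0.01)  # stand-in for actual image generation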
@@ -267,6 +267,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # return loss
         return 0.0
 
+    # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
+    def diffuse_some_steps(
+            self,
+            latents: torch.FloatTensor,
+            text_embeddings: torch.FloatTensor,
+            total_timesteps: int = 1000,
+            start_timesteps=0,
+            **kwargs,
+    ):
+
+        for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
+            noise_pred = train_util.predict_noise(
+                self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
+            )
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
+
+        # return latents_steps
+        return latents
+
     def run(self):
         super().run()
 
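The new diffuse_some_steps helper runs the scheduler over only a slice of its timesteps, so latents can be denoised part of the way and handed back to the training code (for example, to build partially denoised reference latents). Below is a self-contained sketch of the same loop shape using a diffusers scheduler and a dummy noise predictor standing in for the UNet; it is not the commit's code, just an illustration of the idea:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

def dummy_predict_noise(latents, t):
    # stands in for train_util.predict_noise(unet, scheduler, t, latents, embeddings)
    return torch.zeros_like(latents)

latents = torch.randn(1, 4, 64, 64)
start_step, stop_step = 0, 30  # run only the first 30 of the 50 scheduler steps

for t in scheduler.timesteps[start_step:stop_step]:
    noise_pred = dummy_predict_noise(latents, t)
    # x_t -> x_{t-1}
    latents = scheduler.step(noise_pred, t, latents).prev_sample

# latents are now only partially denoised, matching the helper's intent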
@@ -368,7 +389,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # todo handle dataloader here maybe, not sure
 
             ### HOOK ###
-            loss = self.hook_train_loop()
+            loss_dict = self.hook_train_loop()
 
+            if self.train_config.optimizer.startswith('dadaptation'):
+                learning_rate = (
+                    optimizer.param_groups[0]["d"] *
+                    optimizer.param_groups[0]["lr"]
+                )
+            else:
+                learning_rate = optimizer.param_groups[0]['lr']
+
+            prog_bar_string = f"lr: {learning_rate:.1e}"
+            for key, value in loss_dict.items():
+                prog_bar_string += f" {key}: {value:.3e}"
+
+            self.progress_bar.set_postfix_str(prog_bar_string)
+
             # don't do on first step
             if self.step_num != self.start_step:
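The progress bar now reports an effective learning rate: D-Adaptation optimizers keep the configured lr (usually 1.0) fixed and learn a multiplier d per parameter group, so the step size actually applied is d * lr, which is what the branch above reads out of param_groups. A hedged sketch of the same read-out, assuming the dadaptation package is installed:

import torch
from dadaptation import DAdaptAdam

params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = DAdaptAdam(params, lr=1.0)

group = optimizer.param_groups[0]
effective_lr = group["d"] * group["lr"]  # what the progress bar reports as "lr"
print(f"lr: {effective_lr:.1e}")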
@@ -386,15 +421,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
                     # log to tensorboard
                     if self.writer is not None:
-                        # get avg loss
-                        self.writer.add_scalar(f"loss", loss, self.step_num)
-                        if self.train_config.optimizer.startswith('dadaptation'):
-                            learning_rate = (
-                                optimizer.param_groups[0]["d"] *
-                                optimizer.param_groups[0]["lr"]
-                            )
-                        else:
-                            learning_rate = optimizer.param_groups[0]['lr']
+                        for key, value in loss_dict.items():
+                            self.writer.add_scalar(f"{key}", value, self.step_num)
                         self.writer.add_scalar(f"lr", learning_rate, self.step_num)
                 self.progress_bar.refresh()
 
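Since hook_train_loop now returns a dict of named losses instead of a single scalar, the tensorboard block writes one scalar per key rather than a single "loss" curve. A minimal sketch of the same pattern with an ordinary SummaryWriter (requires the tensorboard package); the tag names and values below are illustrative only, not real results:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="output/.tensorboard")  # hypothetical log directory
loss_dict = {"loss": 0.123, "anchor_loss": 0.045}      # example values
step_num = 100

# one curve per named loss, plus the learning rate
for key, value in loss_dict.items():
    writer.add_scalar(key, value, step_num)
writer.add_scalar("lr", 1e-4, step_num)
writer.flush()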