Feat: Wandb logging (#95)

* wandb logging

* fix: start logging before train loop

* chore: add wandb dir to gitignore

* fix: wrap wandb functions

* fix: forget to send last samples

* chore: use valid type

* chore: use None when not type-checking

* chore: resolved complicated logic

* fix: follow log_every

---------

Co-authored-by: Plat <github@p1at.dev>
Co-authored-by: Jaret Burkett <jaretburkett@gmail.com>
This commit is contained in:
Plat
2024-09-20 11:01:01 +09:00
committed by GitHub
parent 951e223481
commit 79b4e04b80
6 changed files with 136 additions and 7 deletions

View File

@@ -55,9 +55,9 @@ import gc
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig
from toolkit.logging import create_logger
def flush():
torch.cuda.empty_cache()
@@ -102,7 +102,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
self.has_first_sample_requested = False
self.first_sample_config = self.sample_config
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.logging_config = LoggingConfig(**self.get_conf('logging', {}))
self.logger = create_logger(self.logging_config, config)
self.optimizer: torch.optim.Optimizer = None
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
@@ -258,6 +259,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
refiner_start_at=sample_config.refiner_start_at,
extra_values=sample_config.extra_values,
logger=self.logger,
**extra_args
))
@@ -568,7 +570,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
return params
def hook_before_train_loop(self):
pass
self.logger.start()
def ensure_params_requires_grad(self):
# get param groups
@@ -1627,6 +1629,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# TRAIN LOOP
###################################################################
start_step_num = self.step_num
did_first_flush = False
for step in range(start_step_num, self.train_config.steps):
@@ -1767,6 +1770,25 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.unpause()
# log to logger
self.logger.log({
'learning_rate': learning_rate,
})
for key, value in loss_dict.items():
self.logger.log({
f'loss/{key}': value,
})
elif self.logging_config.log_every is None:
# log every step
self.logger.log({
'learning_rate': learning_rate,
})
for key, value in loss_dict.items():
self.logger.log({
f'loss/{key}': value,
})
if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
self.progress_bar.pause()
@@ -1774,6 +1796,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.timer.print()
self.timer.reset()
self.progress_bar.unpause()
# commit log
self.logger.commit(step=self.step_num)
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
@@ -1796,8 +1821,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.pipeline.disable_freeu()
if not self.train_config.disable_sampling:
self.sample(self.step_num)
self.logger.commit(step=self.step_num)
print("")
self.save()
self.logger.finish()
if self.save_config.push_to_hub:
if("HF_TOKEN" not in os.environ):
interpreter_login(new_session=False, write_permission=True)