From 79b4e04b80103628870de7e2566c8143d4fe902d Mon Sep 17 00:00:00 2001
From: Plat <60182057+p1atdev@users.noreply.github.com>
Date: Fri, 20 Sep 2024 11:01:01 +0900
Subject: [PATCH] Feat: Wandb logging (#95)

* wandb logging

* fix: start logging before train loop

* chore: add wandb dir to gitignore

* fix: wrap wandb functions

* fix: forget to send last samples

* chore: use valid type

* chore: use None when not type-checking

* chore: resolved complicated logic

* fix: follow log_every

---------

Co-authored-by: Plat
Co-authored-by: Jaret Burkett
---
 .gitignore                                  |  3 +-
 extensions_built_in/sd_trainer/SDTrainer.py |  2 +
 jobs/process/BaseSDTrainProcess.py          | 36 ++++++++-
 toolkit/config_modules.py                   | 17 ++++-
 toolkit/logging.py                          | 84 +++++++++++++++++++++
 toolkit/stable_diffusion_model.py           |  1 +
 6 files changed, 136 insertions(+), 7 deletions(-)
 create mode 100644 toolkit/logging.py

diff --git a/.gitignore b/.gitignore
index edb8d50d..9e03d70f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -173,4 +173,5 @@ cython_debug/
 !/output/.gitkeep
 /extensions/*
 !/extensions/example
-/temp
\ No newline at end of file
+/temp
+/wandb
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index e2e62b94..5b615f27 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -113,6 +113,8 @@ class SDTrainer(BaseSDTrainProcess):
             self.taesd.requires_grad_(False)
 
     def hook_before_train_loop(self):
+        super().hook_before_train_loop()
+
         if self.train_config.do_prior_divergence:
             self.do_prior_prediction = True
         # move vae to device if we did not cache latents
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index c50a5cdc..de812e9b 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -55,9 +55,9 @@ import gc
 
 from tqdm import tqdm
 
-from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
+from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
     GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig
-
+from toolkit.logging import create_logger
 
 def flush():
     torch.cuda.empty_cache()
@@ -102,7 +102,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             self.has_first_sample_requested = False
             self.first_sample_config = self.sample_config
-        self.logging_config = LogingConfig(**self.get_conf('logging', {}))
+        self.logging_config = LoggingConfig(**self.get_conf('logging', {}))
+        self.logger = create_logger(self.logging_config, config)
         self.optimizer: torch.optim.Optimizer = None
         self.lr_scheduler = None
         self.data_loader: Union[DataLoader, None] = None
@@ -258,6 +259,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
                 refiner_start_at=sample_config.refiner_start_at,
                 extra_values=sample_config.extra_values,
+                logger=self.logger,
                 **extra_args
             ))
 
@@ -568,7 +570,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return params
 
     def hook_before_train_loop(self):
-        pass
+        self.logger.start()
 
     def ensure_params_requires_grad(self):
         # get param groups
@@ -1627,6 +1629,7 @@
 
         # TRAIN LOOP
         ###################################################################
+
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
@@ -1767,6 +1770,25 @@ class BaseSDTrainProcess(BaseTrainProcess):
                         self.writer.add_scalar(f"{key}", value, self.step_num)
                     self.writer.add_scalar(f"lr", learning_rate, self.step_num)
 
                 self.progress_bar.unpause()
+
+                # log to logger
+                self.logger.log({
+                    'learning_rate': learning_rate,
+                })
+                for key, value in loss_dict.items():
+                    self.logger.log({
+                        f'loss/{key}': value,
+                    })
+            elif self.logging_config.log_every is None:
+                # log every step
+                self.logger.log({
+                    'learning_rate': learning_rate,
+                })
+                for key, value in loss_dict.items():
+                    self.logger.log({
+                        f'loss/{key}': value,
+                    })
+
             if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
                 self.progress_bar.pause()
@@ -1774,6 +1796,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.timer.print()
                 self.timer.reset()
                 self.progress_bar.unpause()
+
+            # commit log
+            self.logger.commit(step=self.step_num)
 
             # sets progress bar to match out step
             self.progress_bar.update(step - self.progress_bar.n)
@@ -1796,8 +1821,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.sd.pipeline.disable_freeu()
         if not self.train_config.disable_sampling:
             self.sample(self.step_num)
+            self.logger.commit(step=self.step_num)
         print("")
         self.save()
+        self.logger.finish()
+
         if self.save_config.push_to_hub:
             if("HF_TOKEN" not in os.environ):
                 interpreter_login(new_session=False, write_permission=True)
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index b8c82c8e..a8eb35fb 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -13,7 +13,9 @@ SaveFormat = Literal['safetensors', 'diffusers']
 
 if TYPE_CHECKING:
     from toolkit.guidance import GuidanceType
-
+    from toolkit.logging import EmptyLogger
+else:
+    EmptyLogger = None
 
 class SaveConfig:
     def __init__(self, **kwargs):
@@ -27,11 +29,13 @@ class SaveConfig:
         self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
         self.hf_private: Optional[str] = kwargs.get("hf_private", False)
 
-class LogingConfig:
+class LoggingConfig:
     def __init__(self, **kwargs):
         self.log_every: int = kwargs.get('log_every', 100)
         self.verbose: bool = kwargs.get('verbose', False)
         self.use_wandb: bool = kwargs.get('use_wandb', False)
+        self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
+        self.run_name: str = kwargs.get('run_name', None)
 
 
 class SampleConfig:
@@ -655,6 +659,7 @@ class GenerateImageConfig:
             extra_kwargs: dict = None, # extra data to save with prompt file
             refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
             extra_values: List[float] = None, # extra values to save with prompt file
+            logger: Optional[EmptyLogger] = None,
     ):
         self.width: int = width
         self.height: int = height
@@ -712,6 +717,8 @@ class GenerateImageConfig:
         self.height = max(64, self.height - self.height % 8) # round to divisible by 8
         self.width = max(64, self.width - self.width % 8) # round to divisible by 8
 
+        self.logger = logger
+
     def set_gen_time(self, gen_time: int = None):
         if gen_time is not None:
             self.gen_time = gen_time
@@ -855,3 +862,9 @@ class GenerateImageConfig:
     ):
         # this is called after prompt embeds are encoded. We can override them in the future here
         pass
+
+    def log_image(self, image, count: int = 0, max_count=0):
+        if self.logger is None:
+            return
+
+        self.logger.log_image(image, count, self.prompt)
\ No newline at end of file
diff --git a/toolkit/logging.py b/toolkit/logging.py
new file mode 100644
index 00000000..56b1c8b5
--- /dev/null
+++ b/toolkit/logging.py
@@ -0,0 +1,84 @@
+from typing import OrderedDict, Optional
+from PIL import Image
+
+from toolkit.config_modules import LoggingConfig
+
+# Base logger class
+# This class does nothing, it's just a placeholder
+class EmptyLogger:
+    def __init__(self, *args, **kwargs) -> None:
+        pass
+
+    # start logging the training
+    def start(self):
+        pass
+
+    # collect the log to send
+    def log(self, *args, **kwargs):
+        pass
+
+    # send the log
+    def commit(self, step: Optional[int] = None):
+        pass
+
+    # log image
+    def log_image(self, *args, **kwargs):
+        pass
+
+    # finish logging
+    def finish(self):
+        pass
+
+# Wandb logger class
+# This class logs the data to wandb
+class WandbLogger(EmptyLogger):
+    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
+        self.project = project
+        self.run_name = run_name
+        self.config = config
+
+    def start(self):
+        try:
+            import wandb
+        except ImportError:
+            raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`")
+
+        # send the whole config to wandb
+        run = wandb.init(project=self.project, name=self.run_name, config=self.config)
+        self.run = run
+        self._log = wandb.log # log function
+        self._image = wandb.Image # image object
+
+    def log(self, *args, **kwargs):
+        # when commit is False, wandb increments the step,
+        # but we don't want that to happen, so we set commit=False
+        self._log(*args, **kwargs, commit=False)
+
+    def commit(self, step: Optional[int] = None):
+        # after overall one step is done, we commit the log
+        # by log empty object with commit=True
+        self._log({}, step=step, commit=True)
+
+    def log_image(
+        self,
+        image: Image,
+        id, # sample index
+        caption: str | None = None, # positive prompt
+        *args,
+        **kwargs,
+    ):
+        # create a wandb image object and log it
+        image = self._image(image, caption=caption, *args, **kwargs)
+        self._log({f"sample_{id}": image}, commit=False)
+
+    def finish(self):
+        self.run.finish()
+
+# create logger based on the logging config
+def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
+    if logging_config.use_wandb:
+        project_name = logging_config.project_name
+        run_name = logging_config.run_name
+        return WandbLogger(project=project_name, run_name=run_name, config=all_config)
+    else:
+        return EmptyLogger()
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 651a92e6..b07fbc68 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1318,6 +1318,7 @@ class StableDiffusion:
                 ).images[0]
 
                 gen_config.save_image(img, i)
+                gen_config.log_image(img, i)
 
                 if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
                     self.adapter.clear_memory()
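
Usage sketch (not part of the patch): the snippet below shows how the logger added in toolkit/logging.py is expected to be driven, mirroring the calls BaseSDTrainProcess makes (create_logger at init, start() before the train loop, log()/commit() per step, finish() after saving). The LoggingConfig field names and key prefixes come from the patch; the literal config values and the standalone loop are illustrative assumptions, and the wandb path requires wandb to be installed and authenticated.

# illustrative sketch, assuming ai-toolkit is importable and wandb is set up
from collections import OrderedDict

from toolkit.config_modules import LoggingConfig
from toolkit.logging import create_logger

logging_config = LoggingConfig(
    use_wandb=True,             # False (the default) yields the no-op EmptyLogger
    project_name='ai-toolkit',  # wandb project (patch default)
    run_name=None,              # optional wandb run name
    log_every=100,              # None makes the train loop log every step
)
all_config = OrderedDict({'job': 'example', 'config': {}})  # placeholder job config forwarded to wandb.init

logger = create_logger(logging_config, all_config)
logger.start()                            # wandb.init happens here, not in __init__
for step in range(3):                     # stands in for the train loop
    logger.log({'learning_rate': 1e-4})   # buffered with commit=False
    logger.log({'loss/loss': 0.1})        # loss entries are logged under the 'loss/' prefix
    logger.commit(step=step)              # flushes the buffered values as one wandb step
logger.finish()                           # closes the wandb run

Keeping commit() separate from log() is what lets several log() calls within a single training step land on one wandb step instead of advancing the step counter each time.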