Feat: Wandb logging (#95)

* wandb logging

* fix: start logging before train loop

* chore: add wandb dir to gitignore

* fix: wrap wandb functions

* fix: forget to send last samples

* chore: use valid type

* chore: use None when not type-checking

* chore: resolved complicated logic

* fix: follow log_every

---------

Co-authored-by: Plat <github@p1at.dev>
Co-authored-by: Jaret Burkett <jaretburkett@gmail.com>
This commit is contained in:
Plat
2024-09-20 11:01:01 +09:00
committed by GitHub
parent 951e223481
commit 79b4e04b80
6 changed files with 136 additions and 7 deletions

View File

@@ -13,7 +13,9 @@ SaveFormat = Literal['safetensors', 'diffusers']
if TYPE_CHECKING:
from toolkit.guidance import GuidanceType
from toolkit.logging import EmptyLogger
else:
EmptyLogger = None
class SaveConfig:
def __init__(self, **kwargs):
@@ -27,11 +29,13 @@ class SaveConfig:
self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
self.hf_private: Optional[str] = kwargs.get("hf_private", False)
class LogingConfig:
class LoggingConfig:
def __init__(self, **kwargs):
self.log_every: int = kwargs.get('log_every', 100)
self.verbose: bool = kwargs.get('verbose', False)
self.use_wandb: bool = kwargs.get('use_wandb', False)
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
self.run_name: str = kwargs.get('run_name', None)
class SampleConfig:
@@ -655,6 +659,7 @@ class GenerateImageConfig:
extra_kwargs: dict = None, # extra data to save with prompt file
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
extra_values: List[float] = None, # extra values to save with prompt file
logger: Optional[EmptyLogger] = None,
):
self.width: int = width
self.height: int = height
@@ -712,6 +717,8 @@ class GenerateImageConfig:
self.height = max(64, self.height - self.height % 8) # round to divisible by 8
self.width = max(64, self.width - self.width % 8) # round to divisible by 8
self.logger = logger
def set_gen_time(self, gen_time: int = None):
if gen_time is not None:
self.gen_time = gen_time
@@ -855,3 +862,9 @@ class GenerateImageConfig:
):
# this is called after prompt embeds are encoded. We can override them in the future here
pass
def log_image(self, image, count: int = 0, max_count=0):
if self.logger is None:
return
self.logger.log_image(image, count, self.prompt)

84
toolkit/logging.py Normal file
View File

@@ -0,0 +1,84 @@
from typing import OrderedDict, Optional
from PIL import Image
from toolkit.config_modules import LoggingConfig
# Base logger class
# This class does nothing, it's just a placeholder
class EmptyLogger:
def __init__(self, *args, **kwargs) -> None:
pass
# start logging the training
def start(self):
pass
# collect the log to send
def log(self, *args, **kwargs):
pass
# send the log
def commit(self, step: Optional[int] = None):
pass
# log image
def log_image(self, *args, **kwargs):
pass
# finish logging
def finish(self):
pass
# Wandb logger class
# This class logs the data to wandb
class WandbLogger(EmptyLogger):
def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
self.project = project
self.run_name = run_name
self.config = config
def start(self):
try:
import wandb
except ImportError:
raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`")
# send the whole config to wandb
run = wandb.init(project=self.project, name=self.run_name, config=self.config)
self.run = run
self._log = wandb.log # log function
self._image = wandb.Image # image object
def log(self, *args, **kwargs):
# when commit is False, wandb increments the step,
# but we don't want that to happen, so we set commit=False
self._log(*args, **kwargs, commit=False)
def commit(self, step: Optional[int] = None):
# after overall one step is done, we commit the log
# by log empty object with commit=True
self._log({}, step=step, commit=True)
def log_image(
self,
image: Image,
id, # sample index
caption: str | None = None, # positive prompt
*args,
**kwargs,
):
# create a wandb image object and log it
image = self._image(image, caption=caption, *args, **kwargs)
self._log({f"sample_{id}": image}, commit=False)
def finish(self):
self.run.finish()
# create logger based on the logging config
def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
if logging_config.use_wandb:
project_name = logging_config.project_name
run_name = logging_config.run_name
return WandbLogger(project=project_name, run_name=run_name, config=all_config)
else:
return EmptyLogger()

View File

@@ -1318,6 +1318,7 @@ class StableDiffusion:
).images[0]
gen_config.save_image(img, i)
gen_config.log_image(img, i)
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
self.adapter.clear_memory()