mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
* wandb logging * fix: start logging before train loop * chore: add wandb dir to gitignore * fix: wrap wandb functions * fix: forget to send last samples * chore: use valid type * chore: use None when not type-checking * chore: resolved complicated logic * fix: follow log_every --------- Co-authored-by: Plat <github@p1at.dev> Co-authored-by: Jaret Burkett <jaretburkett@gmail.com>
85 lines
2.5 KiB
Python
85 lines
2.5 KiB
Python
from typing import OrderedDict, Optional
|
|
from PIL import Image
|
|
|
|
from toolkit.config_modules import LoggingConfig
|
|
|
|
class EmptyLogger:
    """Base logger that accepts every call and does nothing.

    It defines the logger interface (start / log / commit / log_image /
    finish) that concrete loggers override. Each method here is a no-op
    returning None, so an instance can stand in wherever a logger is
    expected but no logging backend is wanted.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""

    def start(self):
        """Begin a logging session. No-op in the base class."""

    def log(self, *args, **kwargs):
        """Collect values to be sent later. No-op in the base class."""

    def commit(self, step: Optional[int] = None):
        """Send everything collected so far, optionally tagged with `step`.

        No-op in the base class.
        """

    def log_image(self, *args, **kwargs):
        """Record an image. No-op in the base class."""

    def finish(self):
        """End the logging session. No-op in the base class."""
|
|
|
|
class WandbLogger(EmptyLogger):
    """Logger that sends metrics and sample images to Weights & Biases (wandb).

    `wandb` is imported lazily in `start`, so the package is only needed
    when this logger is actually used. Metrics and images are buffered with
    ``commit=False`` and flushed together as one wandb step by `commit`.
    """

    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
        """Store the run settings; nothing is sent to wandb until `start`.

        Args:
            project: wandb project name.
            run_name: display name for the run, or None to use wandb's default.
            config: configuration passed to ``wandb.init(config=...)``.
        """
        super().__init__()
        self.project = project
        self.run_name = run_name
        self.config = config
        # Assigned by start(); None means the run has not been started, which
        # lets finish() be called safely (e.g. from cleanup code) even if
        # start() never ran.
        self.run = None

    def start(self):
        """Import wandb and initialise the run.

        Raises:
            ImportError: if the ``wandb`` package is not installed.
        """
        try:
            import wandb
        except ImportError as err:
            # chain the original error so the real import failure stays visible
            raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`") from err

        # send the whole config to wandb
        self.run = wandb.init(project=self.project, name=self.run_name, config=self.config)
        self._log = wandb.log  # log function
        self._image = wandb.Image  # image object

    def log(self, *args, **kwargs):
        """Buffer values for the current step without committing it.

        Arguments are forwarded to ``wandb.log``. With the default
        ``commit=True`` every ``wandb.log`` call advances wandb's step; we only
        want one step per training step, so data is accumulated with
        ``commit=False`` and the step is finalised in `commit`.
        """
        self._log(*args, **kwargs, commit=False)

    def commit(self, step: Optional[int] = None):
        """Flush everything buffered since the last commit as one wandb step.

        Args:
            step: step index to record, or None to let wandb use its own counter.
        """
        # logging an empty dict with commit=True finalises the pending step
        self._log({}, step=step, commit=True)

    def log_image(
        self,
        image: Image.Image,
        id,  # sample index; name kept for keyword callers despite shadowing the builtin
        caption: str | None = None,  # positive prompt
        *args,
        **kwargs,
    ):
        """Buffer a sample image under the key ``sample_{id}``.

        Args:
            image: PIL image to upload.
            id: sample index, used to build the wandb key.
            caption: optional caption (the positive prompt).
            *args, **kwargs: forwarded to ``wandb.Image``.
        """
        # extra positional args bind right after `image` either way; writing
        # them before the keyword just makes that binding explicit
        wandb_image = self._image(image, *args, caption=caption, **kwargs)
        self._log({f"sample_{id}": wandb_image}, commit=False)

    def finish(self):
        """Finish the wandb run; does nothing if `start` was never called."""
        if self.run is None:
            return
        self.run.finish()
|
|
|
|
def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
    """Build the logger selected by `logging_config`.

    Args:
        logging_config: logging settings; ``use_wandb`` selects the backend,
            ``project_name`` and ``run_name`` configure the wandb run.
        all_config: full configuration forwarded to the wandb logger.

    Returns:
        A ``WandbLogger`` when ``use_wandb`` is truthy, otherwise a no-op
        ``EmptyLogger``.
    """
    if not logging_config.use_wandb:
        return EmptyLogger()

    return WandbLogger(
        project=logging_config.project_name,
        run_name=logging_config.run_name,
        config=all_config,
    )
|