# ai-toolkit/toolkit/logging_aitk.py
import os
import sqlite3
import time
from typing import Any, Dict, List, Optional, OrderedDict, Tuple

from PIL import Image

from toolkit.config_modules import LoggingConfig


# Base logger class.
# This class does nothing; it is just a placeholder used when logging is disabled.
class EmptyLogger:
    def __init__(self, *args, **kwargs) -> None:
        pass

    # start logging the training
    def start(self):
        pass

    # collect the log to send
    def log(self, *args, **kwargs):
        pass

    # send the log
    def commit(self, step: Optional[int] = None):
        pass

    # log image
    def log_image(self, *args, **kwargs):
        pass

    # finish logging
    def finish(self):
        pass
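
# Design note: EmptyLogger is a null object, so training code can call
# logger.log(...) / logger.commit(...) unconditionally instead of checking
# whether logging is enabled at every call site.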


# Wandb logger class.
# This class logs the data to wandb.
class WandbLogger(EmptyLogger):
    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
        self.project = project
        self.run_name = run_name
        self.config = config

    def start(self):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "Failed to import wandb. Please install wandb by running `pip install wandb`"
            )
        # send the whole config to wandb
        run = wandb.init(project=self.project, name=self.run_name, config=self.config)
        self.run = run
        self._log = wandb.log  # log function
        self._image = wandb.Image  # image object

    def log(self, *args, **kwargs):
        # by default, wandb.log() increments its internal step on every call;
        # we want exactly one step per training step, so we always log with
        # commit=False and advance the step in commit()
        self._log(*args, **kwargs, commit=False)

    def commit(self, step: Optional[int] = None):
        # after one overall step is done, commit the buffered values
        # by logging an empty dict with commit=True
        self._log({}, step=step, commit=True)

    def log_image(
        self,
        image: Image.Image,
        id,  # sample index
        caption: str | None = None,  # positive prompt
        *args,
        **kwargs,
    ):
        # create a wandb image object and log it
        image = self._image(image, caption=caption, *args, **kwargs)
        self._log({f"sample_{id}": image}, commit=False)

    def finish(self):
        self.run.finish()
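
# Usage sketch (illustrative; my_project, loss, lr, and global_step are
# placeholders). The log()/commit() split lets callers accumulate several
# metrics for one training step and flush them under a single wandb step:
#
#   logger = WandbLogger(project="my_project", run_name=None, config=config)
#   logger.start()
#   logger.log({"loss": loss})         # buffered with commit=False
#   logger.log({"learning_rate": lr})  # buffered with commit=False
#   logger.commit(step=global_step)    # one wandb step for everything above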


# UI logger class.
# Writes metrics to a local SQLite database that the UI can read back.
class UILogger:
    def __init__(
        self,
        log_file: str,
        flush_every_n: int = 256,
        flush_every_secs: float = 0.25,
    ) -> None:
        self.log_file = log_file
        self._log_to_commit: Dict[str, Any] = {}
        self._con: Optional[sqlite3.Connection] = None
        self._started = False
        self._step_counter = 0
        # buffered writes
        self._pending_steps: List[Tuple[int, float]] = []
        self._pending_metrics: List[
            Tuple[int, str, Optional[float], Optional[str]]
        ] = []
        self._pending_key_minmax: Dict[str, Tuple[int, int]] = {}
        self._flush_every_n = int(flush_every_n)
        self._flush_every_secs = float(flush_every_secs)
        self._last_flush = time.time()

    # start logging the training
    def start(self):
        if self._started:
            return
        parent = os.path.dirname(os.path.abspath(self.log_file))
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)
        # isolation_level=None puts the connection in autocommit mode;
        # transactions are opened explicitly with BEGIN in _init_schema/_flush
        self._con = sqlite3.connect(self.log_file, timeout=30.0, isolation_level=None)
        # WAL lets other connections (e.g. a UI reader) read while this one writes
        self._con.execute("PRAGMA journal_mode=WAL;")
        self._con.execute("PRAGMA synchronous=NORMAL;")
        self._con.execute("PRAGMA temp_store=MEMORY;")
        self._con.execute("PRAGMA foreign_keys=ON;")
        self._con.execute("PRAGMA busy_timeout=30000;")
        self._init_schema(self._con)
        self._started = True
        self._last_flush = time.time()

    # collect the log to send
    def log(self, log_dict):
        # log_dict is like {'learning_rate': learning_rate}
        if not isinstance(log_dict, dict):
            raise TypeError("log_dict must be a dict")
        self._log_to_commit.update(log_dict)

    # send the log
    def commit(self, step: Optional[int] = None):
        if not self._started:
            self.start()
        if not self._log_to_commit:
            return
        if step is None:
            step = self._step_counter
            self._step_counter += 1
        else:
            step = int(step)
            if step >= self._step_counter:
                self._step_counter = step + 1
        wall_time = time.time()
        # buffer step row (upsert later)
        self._pending_steps.append((step, wall_time))
        # buffer metrics rows + key min/max updates
        for k, v in self._log_to_commit.items():
            k = k if isinstance(k, str) else str(k)
            vr, vt = self._coerce_value(v)
            self._pending_metrics.append((step, k, vr, vt))
            if k in self._pending_key_minmax:
                lo, hi = self._pending_key_minmax[k]
                if step < lo:
                    lo = step
                if step > hi:
                    hi = step
                self._pending_key_minmax[k] = (lo, hi)
            else:
                self._pending_key_minmax[k] = (step, step)
        self._log_to_commit = {}
        # flush when enough rows are buffered or enough time has passed
        now = time.time()
        if (
            len(self._pending_metrics) >= self._flush_every_n
            or (now - self._last_flush) >= self._flush_every_secs
        ):
            self._flush()

    # log image
    def log_image(self, *args, **kwargs):
        # images are not written to the database for now
        pass

    # finish logging
    def finish(self):
        if not self._started:
            return
        self._flush()
        assert self._con is not None
        self._con.close()
        self._con = None
        self._started = False

    # -------------------------
    # internal
    # -------------------------
    def _init_schema(self, con: sqlite3.Connection) -> None:
        con.execute("BEGIN;")
        con.execute("""
            CREATE TABLE IF NOT EXISTS steps (
                step INTEGER PRIMARY KEY,
                wall_time REAL NOT NULL
            );
        """)
        con.execute("""
            CREATE TABLE IF NOT EXISTS metric_keys (
                key TEXT PRIMARY KEY,
                first_seen_step INTEGER,
                last_seen_step INTEGER
            );
        """)
        con.execute("""
            CREATE TABLE IF NOT EXISTS metrics (
                step INTEGER NOT NULL,
                key TEXT NOT NULL,
                value_real REAL,
                value_text TEXT,
                PRIMARY KEY (step, key),
                FOREIGN KEY (step) REFERENCES steps(step) ON DELETE CASCADE
            );
        """)
        con.execute(
            "CREATE INDEX IF NOT EXISTS idx_metrics_key_step ON metrics (key, step);"
        )
        con.execute("COMMIT;")

    def _coerce_value(self, v: Any) -> Tuple[Optional[float], Optional[str]]:
        if v is None:
            return None, None
        if isinstance(v, bool):
            return float(int(v)), None
        if isinstance(v, (int, float)):
            return float(v), None
        try:
            return float(v), None  # type: ignore[arg-type]
        except Exception:
            return None, str(v)
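
    # Examples of the coercion above (derived from the code paths; not part
    # of the original source):
    #   _coerce_value(None)    -> (None, None)      stored as SQL NULLs
    #   _coerce_value(True)    -> (1.0, None)       bools become 0.0 / 1.0
    #   _coerce_value(3)       -> (3.0, None)       numbers go to value_real
    #   _coerce_value("0.5")   -> (0.5, None)       numeric strings are parsed
    #   _coerce_value("adamw") -> (None, "adamw")   everything else to value_text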

    def _flush(self) -> None:
        if not self._pending_steps and not self._pending_metrics:
            return
        assert self._con is not None
        con = self._con
        con.execute("BEGIN;")
        # steps upsert
        if self._pending_steps:
            con.executemany(
                "INSERT INTO steps(step, wall_time) VALUES(?, ?) "
                "ON CONFLICT(step) DO UPDATE SET wall_time=excluded.wall_time;",
                self._pending_steps,
            )
        # keys table upsert (maintains list of keys + seen range)
        if self._pending_key_minmax:
            con.executemany(
                "INSERT INTO metric_keys(key, first_seen_step, last_seen_step) VALUES(?, ?, ?) "
                "ON CONFLICT(key) DO UPDATE SET "
                "first_seen_step=MIN(metric_keys.first_seen_step, excluded.first_seen_step), "
                "last_seen_step=MAX(metric_keys.last_seen_step, excluded.last_seen_step);",
                [(k, lo, hi) for k, (lo, hi) in self._pending_key_minmax.items()],
            )
        # metrics upsert
        if self._pending_metrics:
            con.executemany(
                "INSERT INTO metrics(step, key, value_real, value_text) VALUES(?, ?, ?, ?) "
                "ON CONFLICT(step, key) DO UPDATE SET "
                "value_real=excluded.value_real, value_text=excluded.value_text;",
                self._pending_metrics,
            )
        con.execute("COMMIT;")
        self._pending_steps.clear()
        self._pending_metrics.clear()
        self._pending_key_minmax.clear()
        self._last_flush = time.time()
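
# Read-back sketch (illustrative; the consuming UI code is not part of this
# file). Given the schema created in _init_schema, one metric series can be
# fetched with plain sqlite3; "loss" is a placeholder key:
#
#   import sqlite3
#   con = sqlite3.connect("loss_log.db")
#   rows = con.execute(
#       "SELECT step, value_real FROM metrics WHERE key = ? ORDER BY step;",
#       ("loss",),
#   ).fetchall()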


# create a logger based on the logging config
def create_logger(
    logging_config: LoggingConfig,
    all_config: OrderedDict,
    save_root: Optional[str] = None,
):
    if logging_config.use_wandb:
        project_name = logging_config.project_name
        run_name = logging_config.run_name
        return WandbLogger(project=project_name, run_name=run_name, config=all_config)
    elif logging_config.use_ui_logger:
        if save_root is None:
            raise ValueError("save_root must be provided when using UILogger")
        log_file = os.path.join(save_root, "loss_log.db")
        return UILogger(log_file=log_file)
    else:
        return EmptyLogger()
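
# Minimal end-to-end sketch (assumes a LoggingConfig with use_ui_logger=True;
# the LoggingConfig constructor lives in toolkit.config_modules and is not
# shown here, and train_one_step / num_steps are placeholders):
#
#   logger = create_logger(logging_config, all_config, save_root="output/run1")
#   logger.start()
#   for step in range(num_steps):
#       loss = train_one_step()
#       logger.log({"loss": float(loss)})
#       logger.commit(step=step)
#   logger.finish()  # flushes buffered rows and closes the database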