import os
import sqlite3
import time
from typing import Any, Dict, List, Optional, OrderedDict, Tuple

from PIL import Image

from toolkit.config_modules import LoggingConfig


# Base logger class
# This class does nothing, it's just a placeholder
class EmptyLogger:
    def __init__(self, *args, **kwargs) -> None:
        pass

    # start logging the training
    def start(self):
        pass

    # collect the log to send
    def log(self, *args, **kwargs):
        pass

    # send the log
    def commit(self, step: Optional[int] = None):
        pass

    # log image
    def log_image(self, *args, **kwargs):
        pass

    # finish logging
    def finish(self):
        pass


# Wandb logger class
# This class logs the data to wandb
class WandbLogger(EmptyLogger):
    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
        self.project = project
        self.run_name = run_name
        self.config = config

    def start(self):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "Failed to import wandb. Please install wandb by running `pip install wandb`"
            )

        # send the whole config to wandb
        run = wandb.init(project=self.project, name=self.run_name, config=self.config)
        self.run = run
        self._log = wandb.log  # log function
        self._image = wandb.Image  # image object

    def log(self, *args, **kwargs):
        # wandb advances its internal step every time log() is called with
        # commit=True (the default); we pass commit=False so metrics accumulate
        # until commit() is called once per training step
        self._log(*args, **kwargs, commit=False)

    def commit(self, step: Optional[int] = None):
        # once the whole training step is done, send the accumulated log
        # by logging an empty dict with commit=True
        self._log({}, step=step, commit=True)

    def log_image(
        self,
        image: Image.Image,
        id,  # sample index
        caption: str | None = None,  # positive prompt
        *args,
        **kwargs,
    ):
        # create a wandb image object and log it
        image = self._image(image, caption=caption, *args, **kwargs)
        self._log({f"sample_{id}": image}, commit=False)

    def finish(self):
        self.run.finish()
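

# Illustrative sketch (assumed call pattern, not part of this file): a training
# loop would typically drive any of these loggers the same way, for example
#   logger = WandbLogger(project="my-project", run_name=None, config=config)
#   logger.start()
#   logger.log({"loss": loss_value, "lr": learning_rate})
#   logger.log_image(sample_image, 0, caption="positive prompt")
#   logger.commit(step)
#   ...
#   logger.finish()
# The names "my-project", loss_value, learning_rate, sample_image and step are
# placeholders, not identifiers defined by the toolkit.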


# UI logger class
# Buffers scalar metrics in memory and periodically writes them to a local
# SQLite database so they can be read back later (e.g. by the UI)
class UILogger:
    def __init__(
        self,
        log_file: str,
        flush_every_n: int = 256,
        flush_every_secs: float = 0.25,
    ) -> None:
        self.log_file = log_file
        self._log_to_commit: Dict[str, Any] = {}

        self._con: Optional[sqlite3.Connection] = None
        self._started = False

        self._step_counter = 0

        # buffered writes
        self._pending_steps: List[Tuple[int, float]] = []
        self._pending_metrics: List[
            Tuple[int, str, Optional[float], Optional[str]]
        ] = []
        self._pending_key_minmax: Dict[str, Tuple[int, int]] = {}

        self._flush_every_n = int(flush_every_n)
        self._flush_every_secs = float(flush_every_secs)
        self._last_flush = time.time()

    # start logging the training
    def start(self):
        if self._started:
            return

        parent = os.path.dirname(os.path.abspath(self.log_file))
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        self._con = sqlite3.connect(self.log_file, timeout=30.0, isolation_level=None)
        # WAL mode allows another process to read the database while training writes to it
        self._con.execute("PRAGMA journal_mode=WAL;")
        self._con.execute("PRAGMA synchronous=NORMAL;")
        self._con.execute("PRAGMA temp_store=MEMORY;")
        self._con.execute("PRAGMA foreign_keys=ON;")
        self._con.execute("PRAGMA busy_timeout=30000;")

        self._init_schema(self._con)

        self._started = True
        self._last_flush = time.time()

    # collect the log to send
    def log(self, log_dict):
        # log_dict is like {'learning_rate': learning_rate}
        if not isinstance(log_dict, dict):
            raise TypeError("log_dict must be a dict")
        self._log_to_commit.update(log_dict)

    # send the log
    def commit(self, step: Optional[int] = None):
        if not self._started:
            self.start()

        if not self._log_to_commit:
            return

        if step is None:
            step = self._step_counter
            self._step_counter += 1
        else:
            step = int(step)
            if step >= self._step_counter:
                self._step_counter = step + 1

        wall_time = time.time()

        # buffer step row (upsert later)
        self._pending_steps.append((step, wall_time))

        # buffer metrics rows + key min/max updates
        for k, v in self._log_to_commit.items():
            k = k if isinstance(k, str) else str(k)
            vr, vt = self._coerce_value(v)

            self._pending_metrics.append((step, k, vr, vt))

            if k in self._pending_key_minmax:
                lo, hi = self._pending_key_minmax[k]
                if step < lo:
                    lo = step
                if step > hi:
                    hi = step
                self._pending_key_minmax[k] = (lo, hi)
            else:
                self._pending_key_minmax[k] = (step, step)

        self._log_to_commit = {}

        # flush conditions
        now = time.time()
        if (
            len(self._pending_metrics) >= self._flush_every_n
            or (now - self._last_flush) >= self._flush_every_secs
        ):
            self._flush()
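
    # Illustrative note (assumption about typical usage, not enforced here):
    # a training loop calls something like
    #   logger.log({"loss": 0.31, "lr": 1e-4})
    #   logger.commit(step)
    # once per step; with the defaults above, buffered rows are only written to
    # SQLite once 256 metric rows have accumulated or 0.25 s have passed since
    # the last flush, whichever is hit first at commit() time.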

    # log image
    def log_image(self, *args, **kwargs):
        # image logging is not supported yet
        pass

    # finish logging
    def finish(self):
        if not self._started:
            return

        self._flush()

        assert self._con is not None
        self._con.close()
        self._con = None
        self._started = False

    # -------------------------
    # internal
    # -------------------------

    def _init_schema(self, con: sqlite3.Connection) -> None:
        con.execute("BEGIN;")

        con.execute("""
            CREATE TABLE IF NOT EXISTS steps (
                step INTEGER PRIMARY KEY,
                wall_time REAL NOT NULL
            );
        """)

        con.execute("""
            CREATE TABLE IF NOT EXISTS metric_keys (
                key TEXT PRIMARY KEY,
                first_seen_step INTEGER,
                last_seen_step INTEGER
            );
        """)

        con.execute("""
            CREATE TABLE IF NOT EXISTS metrics (
                step INTEGER NOT NULL,
                key TEXT NOT NULL,
                value_real REAL,
                value_text TEXT,
                PRIMARY KEY (step, key),
                FOREIGN KEY (step) REFERENCES steps(step) ON DELETE CASCADE
            );
        """)

        con.execute(
            "CREATE INDEX IF NOT EXISTS idx_metrics_key_step ON metrics (key, step);"
        )

        con.execute("COMMIT;")
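
    # Example read-side queries against this schema (an assumption about how a
    # consumer such as the UI might read the data; this class never runs them):
    #   SELECT key FROM metric_keys;                        -- list logged keys
    #   SELECT step, value_real FROM metrics
    #   WHERE key = 'loss' ORDER BY step;                   -- one metric series
    # ('loss' is a placeholder key name.)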

    def _coerce_value(self, v: Any) -> Tuple[Optional[float], Optional[str]]:
        if v is None:
            return None, None
        if isinstance(v, bool):
            return float(int(v)), None
        if isinstance(v, (int, float)):
            return float(v), None
        try:
            return float(v), None  # type: ignore[arg-type]
        except Exception:
            return None, str(v)
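
    # For reference, the coercion above maps inputs to (value_real, value_text):
    #   1.5     -> (1.5, None)
    #   True    -> (1.0, None)
    #   "0.25"  -> (0.25, None)     # numeric strings parse via float()
    #   "adamw" -> (None, "adamw")  # non-numeric values are stored as text
    #   None    -> (None, None)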

    def _flush(self) -> None:
        if not self._pending_steps and not self._pending_metrics:
            return

        assert self._con is not None
        con = self._con

        con.execute("BEGIN;")

        # steps upsert
        if self._pending_steps:
            con.executemany(
                "INSERT INTO steps(step, wall_time) VALUES(?, ?) "
                "ON CONFLICT(step) DO UPDATE SET wall_time=excluded.wall_time;",
                self._pending_steps,
            )

        # keys table upsert (maintains list of keys + seen range)
        if self._pending_key_minmax:
            con.executemany(
                "INSERT INTO metric_keys(key, first_seen_step, last_seen_step) VALUES(?, ?, ?) "
                "ON CONFLICT(key) DO UPDATE SET "
                "first_seen_step=MIN(metric_keys.first_seen_step, excluded.first_seen_step), "
                "last_seen_step=MAX(metric_keys.last_seen_step, excluded.last_seen_step);",
                [(k, lo, hi) for k, (lo, hi) in self._pending_key_minmax.items()],
            )

        # metrics upsert
        if self._pending_metrics:
            con.executemany(
                "INSERT INTO metrics(step, key, value_real, value_text) VALUES(?, ?, ?, ?) "
                "ON CONFLICT(step, key) DO UPDATE SET "
                "value_real=excluded.value_real, value_text=excluded.value_text;",
                self._pending_metrics,
            )

        con.execute("COMMIT;")

        self._pending_steps.clear()
        self._pending_metrics.clear()
        self._pending_key_minmax.clear()
        self._last_flush = time.time()


# create logger based on the logging config
def create_logger(
    logging_config: LoggingConfig,
    all_config: OrderedDict,
    save_root: Optional[str] = None,
):
    if logging_config.use_wandb:
        project_name = logging_config.project_name
        run_name = logging_config.run_name
        return WandbLogger(project=project_name, run_name=run_name, config=all_config)
    elif logging_config.use_ui_logger:
        if save_root is None:
            raise ValueError("save_root must be provided when using UILogger")
        log_file = os.path.join(save_root, "loss_log.db")
        return UILogger(log_file=log_file)
    else:
        return EmptyLogger()
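

# Minimal smoke-test sketch (an illustration only, not part of the toolkit):
# exercises UILogger directly with made-up metric values and a temporary file.
if __name__ == "__main__":
    import tempfile

    demo_file = os.path.join(tempfile.mkdtemp(), "loss_log.db")
    demo = UILogger(log_file=demo_file)
    demo.start()
    for demo_step in range(3):
        demo.log({"loss": 1.0 / (demo_step + 1), "lr": 1e-4})
        demo.commit(demo_step)
    demo.finish()
    print(f"wrote demo metrics to {demo_file}")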