mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add loss graph to the ui
This commit is contained in:
@@ -35,6 +35,7 @@ class LoggingConfig:
|
||||
self.log_every: int = kwargs.get('log_every', 100)
|
||||
self.verbose: bool = kwargs.get('verbose', False)
|
||||
self.use_wandb: bool = kwargs.get('use_wandb', False)
|
||||
self.use_ui_logger: bool = kwargs.get('use_ui_logger', False)
|
||||
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
|
||||
self.run_name: str = kwargs.get('run_name', None)
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ from typing import OrderedDict, Optional
|
||||
from PIL import Image
|
||||
|
||||
from toolkit.config_modules import LoggingConfig
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
|
||||
# Base logger class
|
||||
# This class does nothing, it's just a placeholder
|
||||
@@ -12,11 +17,11 @@ class EmptyLogger:
|
||||
# start logging the training
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
|
||||
# collect the log to send
|
||||
def log(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# send the log
|
||||
def commit(self, step: Optional[int] = None):
|
||||
pass
|
||||
@@ -29,6 +34,7 @@ class EmptyLogger:
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
|
||||
# Wandb logger class
|
||||
# This class logs the data to wandb
|
||||
class WandbLogger(EmptyLogger):
|
||||
@@ -41,13 +47,15 @@ class WandbLogger(EmptyLogger):
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`")
|
||||
|
||||
raise ImportError(
|
||||
"Failed to import wandb. Please install wandb by running `pip install wandb`"
|
||||
)
|
||||
|
||||
# send the whole config to wandb
|
||||
run = wandb.init(project=self.project, name=self.run_name, config=self.config)
|
||||
self.run = run
|
||||
self._log = wandb.log # log function
|
||||
self._image = wandb.Image # image object
|
||||
self._log = wandb.log # log function
|
||||
self._image = wandb.Image # image object
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
# when commit is False, wandb increments the step,
|
||||
@@ -74,11 +82,235 @@ class WandbLogger(EmptyLogger):
|
||||
def finish(self):
|
||||
self.run.finish()
|
||||
|
||||
|
||||
class UILogger:
|
||||
def __init__(
|
||||
self,
|
||||
log_file: str,
|
||||
flush_every_n: int = 256,
|
||||
flush_every_secs: float = 0.25,
|
||||
) -> None:
|
||||
self.log_file = log_file
|
||||
self._log_to_commit: Dict[str, Any] = {}
|
||||
|
||||
self._con: Optional[sqlite3.Connection] = None
|
||||
self._started = False
|
||||
|
||||
self._step_counter = 0
|
||||
|
||||
# buffered writes
|
||||
self._pending_steps: List[Tuple[int, float]] = []
|
||||
self._pending_metrics: List[
|
||||
Tuple[int, str, Optional[float], Optional[str]]
|
||||
] = []
|
||||
self._pending_key_minmax: Dict[str, Tuple[int, int]] = {}
|
||||
|
||||
self._flush_every_n = int(flush_every_n)
|
||||
self._flush_every_secs = float(flush_every_secs)
|
||||
self._last_flush = time.time()
|
||||
|
||||
# start logging the training
|
||||
def start(self):
|
||||
if self._started:
|
||||
return
|
||||
|
||||
parent = os.path.dirname(os.path.abspath(self.log_file))
|
||||
if parent and not os.path.exists(parent):
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
self._con = sqlite3.connect(self.log_file, timeout=30.0, isolation_level=None)
|
||||
self._con.execute("PRAGMA journal_mode=WAL;")
|
||||
self._con.execute("PRAGMA synchronous=NORMAL;")
|
||||
self._con.execute("PRAGMA temp_store=MEMORY;")
|
||||
self._con.execute("PRAGMA foreign_keys=ON;")
|
||||
self._con.execute("PRAGMA busy_timeout=30000;")
|
||||
|
||||
self._init_schema(self._con)
|
||||
|
||||
self._started = True
|
||||
self._last_flush = time.time()
|
||||
|
||||
# collect the log to send
|
||||
def log(self, log_dict):
|
||||
# log_dict is like {'learning_rate': learning_rate}
|
||||
if not isinstance(log_dict, dict):
|
||||
raise TypeError("log_dict must be a dict")
|
||||
self._log_to_commit.update(log_dict)
|
||||
|
||||
# send the log
|
||||
def commit(self, step: Optional[int] = None):
|
||||
if not self._started:
|
||||
self.start()
|
||||
|
||||
if not self._log_to_commit:
|
||||
return
|
||||
|
||||
if step is None:
|
||||
step = self._step_counter
|
||||
self._step_counter += 1
|
||||
else:
|
||||
step = int(step)
|
||||
if step >= self._step_counter:
|
||||
self._step_counter = step + 1
|
||||
|
||||
wall_time = time.time()
|
||||
|
||||
# buffer step row (upsert later)
|
||||
self._pending_steps.append((step, wall_time))
|
||||
|
||||
# buffer metrics rows + key min/max updates
|
||||
for k, v in self._log_to_commit.items():
|
||||
k = k if isinstance(k, str) else str(k)
|
||||
vr, vt = self._coerce_value(v)
|
||||
|
||||
self._pending_metrics.append((step, k, vr, vt))
|
||||
|
||||
if k in self._pending_key_minmax:
|
||||
lo, hi = self._pending_key_minmax[k]
|
||||
if step < lo:
|
||||
lo = step
|
||||
if step > hi:
|
||||
hi = step
|
||||
self._pending_key_minmax[k] = (lo, hi)
|
||||
else:
|
||||
self._pending_key_minmax[k] = (step, step)
|
||||
|
||||
self._log_to_commit = {}
|
||||
|
||||
# flush conditions
|
||||
now = time.time()
|
||||
if (
|
||||
len(self._pending_metrics) >= self._flush_every_n
|
||||
or (now - self._last_flush) >= self._flush_every_secs
|
||||
):
|
||||
self._flush()
|
||||
|
||||
# log image
|
||||
def log_image(self, *args, **kwargs):
|
||||
# this doesnt log images for now
|
||||
pass
|
||||
|
||||
# finish logging
|
||||
def finish(self):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
self._flush()
|
||||
|
||||
assert self._con is not None
|
||||
self._con.close()
|
||||
self._con = None
|
||||
self._started = False
|
||||
|
||||
# -------------------------
|
||||
# internal
|
||||
# -------------------------
|
||||
|
||||
def _init_schema(self, con: sqlite3.Connection) -> None:
|
||||
con.execute("BEGIN;")
|
||||
|
||||
con.execute("""
|
||||
CREATE TABLE IF NOT EXISTS steps (
|
||||
step INTEGER PRIMARY KEY,
|
||||
wall_time REAL NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
con.execute("""
|
||||
CREATE TABLE IF NOT EXISTS metric_keys (
|
||||
key TEXT PRIMARY KEY,
|
||||
first_seen_step INTEGER,
|
||||
last_seen_step INTEGER
|
||||
);
|
||||
""")
|
||||
|
||||
con.execute("""
|
||||
CREATE TABLE IF NOT EXISTS metrics (
|
||||
step INTEGER NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value_real REAL,
|
||||
value_text TEXT,
|
||||
PRIMARY KEY (step, key),
|
||||
FOREIGN KEY (step) REFERENCES steps(step) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
con.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_metrics_key_step ON metrics (key, step);"
|
||||
)
|
||||
|
||||
con.execute("COMMIT;")
|
||||
|
||||
def _coerce_value(self, v: Any) -> Tuple[Optional[float], Optional[str]]:
|
||||
if v is None:
|
||||
return None, None
|
||||
if isinstance(v, bool):
|
||||
return float(int(v)), None
|
||||
if isinstance(v, (int, float)):
|
||||
return float(v), None
|
||||
try:
|
||||
return float(v), None # type: ignore[arg-type]
|
||||
except Exception:
|
||||
return None, str(v)
|
||||
|
||||
def _flush(self) -> None:
|
||||
if not self._pending_steps and not self._pending_metrics:
|
||||
return
|
||||
|
||||
assert self._con is not None
|
||||
con = self._con
|
||||
|
||||
con.execute("BEGIN;")
|
||||
|
||||
# steps upsert
|
||||
if self._pending_steps:
|
||||
con.executemany(
|
||||
"INSERT INTO steps(step, wall_time) VALUES(?, ?) "
|
||||
"ON CONFLICT(step) DO UPDATE SET wall_time=excluded.wall_time;",
|
||||
self._pending_steps,
|
||||
)
|
||||
|
||||
# keys table upsert (maintains list of keys + seen range)
|
||||
if self._pending_key_minmax:
|
||||
con.executemany(
|
||||
"INSERT INTO metric_keys(key, first_seen_step, last_seen_step) VALUES(?, ?, ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET "
|
||||
"first_seen_step=MIN(metric_keys.first_seen_step, excluded.first_seen_step), "
|
||||
"last_seen_step=MAX(metric_keys.last_seen_step, excluded.last_seen_step);",
|
||||
[(k, lo, hi) for k, (lo, hi) in self._pending_key_minmax.items()],
|
||||
)
|
||||
|
||||
# metrics upsert
|
||||
if self._pending_metrics:
|
||||
con.executemany(
|
||||
"INSERT INTO metrics(step, key, value_real, value_text) VALUES(?, ?, ?, ?) "
|
||||
"ON CONFLICT(step, key) DO UPDATE SET "
|
||||
"value_real=excluded.value_real, value_text=excluded.value_text;",
|
||||
self._pending_metrics,
|
||||
)
|
||||
|
||||
con.execute("COMMIT;")
|
||||
|
||||
self._pending_steps.clear()
|
||||
self._pending_metrics.clear()
|
||||
self._pending_key_minmax.clear()
|
||||
self._last_flush = time.time()
|
||||
|
||||
|
||||
# create logger based on the logging config
|
||||
def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
|
||||
def create_logger(
|
||||
logging_config: LoggingConfig,
|
||||
all_config: OrderedDict,
|
||||
save_root: Optional[str] = None,
|
||||
):
|
||||
if logging_config.use_wandb:
|
||||
project_name = logging_config.project_name
|
||||
run_name = logging_config.run_name
|
||||
return WandbLogger(project=project_name, run_name=run_name, config=all_config)
|
||||
elif logging_config.use_ui_logger:
|
||||
if save_root is None:
|
||||
raise ValueError("save_root must be provided when using UILogger")
|
||||
log_file = os.path.join(save_root, "loss_log.db")
|
||||
return UILogger(log_file=log_file)
|
||||
else:
|
||||
return EmptyLogger()
|
||||
|
||||
Reference in New Issue
Block a user