Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Feat: Wandb logging (#95)
* wandb logging
* fix: start logging before train loop
* chore: add wandb dir to gitignore
* fix: wrap wandb functions
* fix: forget to send last samples
* chore: use valid type
* chore: use None when not type-checking
* chore: resolved complicated logic
* fix: follow log_every

---------

Co-authored-by: Plat <github@p1at.dev>
Co-authored-by: Jaret Burkett <jaretburkett@gmail.com>
.gitignore (vendored) | 3

@@ -173,4 +173,5 @@ cython_debug/
 !/output/.gitkeep
 /extensions/*
 !/extensions/example
-/temp
+/temp
+/wandb
@@ -113,6 +113,8 @@ class SDTrainer(BaseSDTrainProcess):
             self.taesd.requires_grad_(False)
 
     def hook_before_train_loop(self):
+        super().hook_before_train_loop()
+
         if self.train_config.do_prior_divergence:
             self.do_prior_prediction = True
         # move vae to device if we did not cache latents
@@ -55,9 +55,9 @@ import gc
 
 from tqdm import tqdm
 
-from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
+from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
     GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig
+from toolkit.logging import create_logger
 
 def flush():
     torch.cuda.empty_cache()
@@ -102,7 +102,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             self.has_first_sample_requested = False
             self.first_sample_config = self.sample_config
-        self.logging_config = LogingConfig(**self.get_conf('logging', {}))
+        self.logging_config = LoggingConfig(**self.get_conf('logging', {}))
+        self.logger = create_logger(self.logging_config, config)
         self.optimizer: torch.optim.Optimizer = None
         self.lr_scheduler = None
         self.data_loader: Union[DataLoader, None] = None
@@ -258,6 +259,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
                 refiner_start_at=sample_config.refiner_start_at,
                 extra_values=sample_config.extra_values,
+                logger=self.logger,
                 **extra_args
             ))
@@ -568,7 +570,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return params
 
     def hook_before_train_loop(self):
-        pass
+        self.logger.start()
 
     def ensure_params_requires_grad(self):
         # get param groups
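Note how this hunk and the SDTrainer hunk above fit together: the base class now starts the logger in hook_before_train_loop, and the subclass override adds a super() call so that still runs before the train loop. A minimal sketch of that pattern, with only the class and method names taken from the diff and everything else simplified:

class BaseSDTrainProcess:
    def __init__(self, logger):
        # in the real class the logger comes from create_logger(self.logging_config, config)
        self.logger = logger

    def hook_before_train_loop(self):
        # was `pass`; now the wandb (or empty) logger run is started here
        self.logger.start()


class SDTrainer(BaseSDTrainProcess):
    def hook_before_train_loop(self):
        # the new super() call is what makes logging start before the train loop
        super().hook_before_train_loop()
        # trainer-specific setup continues here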
@@ -1627,6 +1629,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # TRAIN LOOP
         ###################################################################
 
+
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
@@ -1767,6 +1770,25 @@ class BaseSDTrainProcess(BaseTrainProcess):
                         self.writer.add_scalar(f"{key}", value, self.step_num)
                     self.writer.add_scalar(f"lr", learning_rate, self.step_num)
                     self.progress_bar.unpause()
+
+                    # log to logger
+                    self.logger.log({
+                        'learning_rate': learning_rate,
+                    })
+                    for key, value in loss_dict.items():
+                        self.logger.log({
+                            f'loss/{key}': value,
+                        })
+                elif self.logging_config.log_every is None:
+                    # log every step
+                    self.logger.log({
+                        'learning_rate': learning_rate,
+                    })
+                    for key, value in loss_dict.items():
+                        self.logger.log({
+                            f'loss/{key}': value,
+                        })
+
+
                 if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
                     self.progress_bar.pause()
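The logger.log / logger.commit split above mirrors wandb's commit semantics: each log call stages metrics for the current training step without advancing wandb's step counter, and a single commit per step flushes them together. A rough standalone illustration with the wandb API; the project, run name, metric values, and step number are placeholders, and it assumes wandb is installed and authenticated:

import wandb

run = wandb.init(project="ai-toolkit", name="example-run")

# stage several metrics for the same step without advancing wandb's internal step
wandb.log({"learning_rate": 1e-4}, commit=False)
wandb.log({"loss/sd": 0.123}, commit=False)

# flush everything staged above as one row, pinned to the trainer's own step number
wandb.log({}, step=42, commit=True)

run.finish()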
@@ -1774,6 +1796,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     self.timer.print()
                     self.timer.reset()
                     self.progress_bar.unpause()
 
+                # commit log
+                self.logger.commit(step=self.step_num)
+
                 # sets progress bar to match out step
                 self.progress_bar.update(step - self.progress_bar.n)
@@ -1796,8 +1821,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.sd.pipeline.disable_freeu()
         if not self.train_config.disable_sampling:
             self.sample(self.step_num)
+            self.logger.commit(step=self.step_num)
         print("")
         self.save()
 
+        self.logger.finish()
+
         if self.save_config.push_to_hub:
             if("HF_TOKEN" not in os.environ):
                 interpreter_login(new_session=False, write_permission=True)
@@ -13,7 +13,9 @@ SaveFormat = Literal['safetensors', 'diffusers']
 
 if TYPE_CHECKING:
     from toolkit.guidance import GuidanceType
+    from toolkit.logging import EmptyLogger
+else:
+    EmptyLogger = None
 
 
 class SaveConfig:
     def __init__(self, **kwargs):
@@ -27,11 +29,13 @@ class SaveConfig:
         self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
         self.hf_private: Optional[str] = kwargs.get("hf_private", False)
 
-class LogingConfig:
+class LoggingConfig:
     def __init__(self, **kwargs):
         self.log_every: int = kwargs.get('log_every', 100)
         self.verbose: bool = kwargs.get('verbose', False)
+        self.use_wandb: bool = kwargs.get('use_wandb', False)
+        self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
+        self.run_name: str = kwargs.get('run_name', None)
 
 
 class SampleConfig:
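Given the kwargs above, wandb logging is switched on through the 'logging' section of a job config, which the trainer feeds into LoggingConfig via self.get_conf('logging', {}). A small sketch of that; the option values are illustrative, not defaults:

from toolkit.config_modules import LoggingConfig

# mirrors what self.get_conf('logging', {}) would return from a job config
logging_config = LoggingConfig(**{
    'log_every': 50,               # default 100; None means log every step
    'use_wandb': True,             # False keeps the no-op EmptyLogger
    'project_name': 'ai-toolkit',  # wandb project name
    'run_name': 'sdxl-lora-test',  # wandb run name (default None)
})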
@@ -655,6 +659,7 @@ class GenerateImageConfig:
            extra_kwargs: dict = None,  # extra data to save with prompt file
            refiner_start_at: float = 0.5,  # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
            extra_values: List[float] = None,  # extra values to save with prompt file
+            logger: Optional[EmptyLogger] = None,
     ):
         self.width: int = width
         self.height: int = height
@@ -712,6 +717,8 @@
         self.height = max(64, self.height - self.height % 8)  # round to divisible by 8
         self.width = max(64, self.width - self.width % 8)  # round to divisible by 8
 
+        self.logger = logger
+
     def set_gen_time(self, gen_time: int = None):
         if gen_time is not None:
             self.gen_time = gen_time
@@ -855,3 +862,9 @@
     ):
         # this is called after prompt embeds are encoded. We can override them in the future here
         pass
+
+    def log_image(self, image, count: int = 0, max_count=0):
+        if self.logger is None:
+            return
+
+        self.logger.log_image(image, count, self.prompt)
toolkit/logging.py (new file) | 84
@@ -0,0 +1,84 @@
+from typing import OrderedDict, Optional
+from PIL import Image
+
+from toolkit.config_modules import LoggingConfig
+
+# Base logger class
+# This class does nothing, it's just a placeholder
+class EmptyLogger:
+    def __init__(self, *args, **kwargs) -> None:
+        pass
+
+    # start logging the training
+    def start(self):
+        pass
+
+    # collect the log to send
+    def log(self, *args, **kwargs):
+        pass
+
+    # send the log
+    def commit(self, step: Optional[int] = None):
+        pass
+
+    # log image
+    def log_image(self, *args, **kwargs):
+        pass
+
+    # finish logging
+    def finish(self):
+        pass
+
+# Wandb logger class
+# This class logs the data to wandb
+class WandbLogger(EmptyLogger):
+    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
+        self.project = project
+        self.run_name = run_name
+        self.config = config
+
+    def start(self):
+        try:
+            import wandb
+        except ImportError:
+            raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`")
+
+        # send the whole config to wandb
+        run = wandb.init(project=self.project, name=self.run_name, config=self.config)
+        self.run = run
+        self._log = wandb.log  # log function
+        self._image = wandb.Image  # image object
+
+    def log(self, *args, **kwargs):
+        # wandb increments its step whenever commit is True (the default),
+        # but we don't want that to happen here, so we pass commit=False
+        self._log(*args, **kwargs, commit=False)
+
+    def commit(self, step: Optional[int] = None):
+        # after one overall step is done, we commit the log
+        # by logging an empty object with commit=True
+        self._log({}, step=step, commit=True)
+
+    def log_image(
+        self,
+        image: Image,
+        id,  # sample index
+        caption: str | None = None,  # positive prompt
+        *args,
+        **kwargs,
+    ):
+        # create a wandb image object and log it
+        image = self._image(image, caption=caption, *args, **kwargs)
+        self._log({f"sample_{id}": image}, commit=False)
+
+    def finish(self):
+        self.run.finish()
+
+# create logger based on the logging config
+def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
+    if logging_config.use_wandb:
+        project_name = logging_config.project_name
+        run_name = logging_config.run_name
+        return WandbLogger(project=project_name, run_name=run_name, config=all_config)
+    else:
+        return EmptyLogger()
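A short usage sketch of the new module, walking the lifecycle the trainer drives (start once, stage with log, flush with commit, finish at the end). The config values here are illustrative; in the trainer, all_config is the full job config dict:

from collections import OrderedDict

from toolkit.config_modules import LoggingConfig
from toolkit.logging import create_logger

logging_config = LoggingConfig(use_wandb=True, project_name='ai-toolkit', run_name='demo')
all_config = OrderedDict(logging={'use_wandb': True, 'run_name': 'demo'})

logger = create_logger(logging_config, all_config)  # WandbLogger here; EmptyLogger if use_wandb is False

logger.start()                       # calls wandb.init under the hood
logger.log({'learning_rate': 1e-4})  # staged with commit=False
logger.log({'loss/sd': 0.321})
logger.commit(step=1)                # one wandb row per training step
logger.finish()                      # closes the wandb run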
@@ -1318,6 +1318,7 @@ class StableDiffusion:
                 ).images[0]
 
                 gen_config.save_image(img, i)
+                gen_config.log_image(img, i)
 
             if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
                 self.adapter.clear_memory()
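With this one-line change each sampled image goes to both sinks: save_image writes it to disk as before, and log_image hands it to the logger, which (for wandb) wraps it in wandb.Image with the positive prompt as caption and stages it under a per-sample key until the commit that follows sampling. A rough sketch of what ends up happening on the wandb side; the image, prompt, project, and step number are placeholders:

import wandb
from PIL import Image

run = wandb.init(project="ai-toolkit", name="example-run")

img = Image.new("RGB", (512, 512))  # stand-in for a generated sample
prompt = "a photo of a cat"

# roughly what WandbLogger.log_image does for sample index 0
wandb.log({"sample_0": wandb.Image(img, caption=prompt)}, commit=False)

# the trainer commits once after sampling so all samples land on the same step
wandb.log({}, step=100, commit=True)

run.finish()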