mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Feat: Wandb logging (#95)
* wandb logging * fix: start logging before train loop * chore: add wandb dir to gitignore * fix: wrap wandb functions * fix: forget to send last samples * chore: use valid type * chore: use None when not type-checking * chore: resolved complicated logic * fix: follow log_every --------- Co-authored-by: Plat <github@p1at.dev> Co-authored-by: Jaret Burkett <jaretburkett@gmail.com>
This commit is contained in:
@@ -13,7 +13,9 @@ SaveFormat = Literal['safetensors', 'diffusers']
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.guidance import GuidanceType
|
||||
|
||||
from toolkit.logging import EmptyLogger
|
||||
else:
|
||||
EmptyLogger = None
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -27,11 +29,13 @@ class SaveConfig:
|
||||
self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
|
||||
self.hf_private: Optional[str] = kwargs.get("hf_private", False)
|
||||
|
||||
class LogingConfig:
|
||||
class LoggingConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.log_every: int = kwargs.get('log_every', 100)
|
||||
self.verbose: bool = kwargs.get('verbose', False)
|
||||
self.use_wandb: bool = kwargs.get('use_wandb', False)
|
||||
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
|
||||
self.run_name: str = kwargs.get('run_name', None)
|
||||
|
||||
|
||||
class SampleConfig:
|
||||
@@ -655,6 +659,7 @@ class GenerateImageConfig:
|
||||
extra_kwargs: dict = None, # extra data to save with prompt file
|
||||
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
|
||||
extra_values: List[float] = None, # extra values to save with prompt file
|
||||
logger: Optional[EmptyLogger] = None,
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -712,6 +717,8 @@ class GenerateImageConfig:
|
||||
self.height = max(64, self.height - self.height % 8) # round to divisible by 8
|
||||
self.width = max(64, self.width - self.width % 8) # round to divisible by 8
|
||||
|
||||
self.logger = logger
|
||||
|
||||
def set_gen_time(self, gen_time: int = None):
|
||||
if gen_time is not None:
|
||||
self.gen_time = gen_time
|
||||
@@ -855,3 +862,9 @@ class GenerateImageConfig:
|
||||
):
|
||||
# this is called after prompt embeds are encoded. We can override them in the future here
|
||||
pass
|
||||
|
||||
def log_image(self, image, count: int = 0, max_count=0):
|
||||
if self.logger is None:
|
||||
return
|
||||
|
||||
self.logger.log_image(image, count, self.prompt)
|
||||
Reference in New Issue
Block a user