Files
ai-toolkit/toolkit/config_modules.py
2023-08-01 13:49:54 -06:00

104 lines
4.2 KiB
Python

from typing import List
class SaveConfig:
    """Checkpoint-saving options parsed from a kwargs mapping."""

    def __init__(self, **kwargs):
        get = kwargs.get
        # How often (in training steps) a checkpoint is written.
        self.save_every: int = get('save_every', 1000)
        # NOTE: the incoming key is 'save_dtype' but the attribute is 'dtype'.
        self.dtype: str = get('save_dtype', 'float16')
        # Step checkpoints beyond this count are eligible for pruning.
        self.max_step_saves_to_keep: int = get('max_step_saves_to_keep', 5)
class LogingConfig:
    """Logging options parsed from a kwargs mapping.

    NOTE(review): class name looks like a typo for 'LoggingConfig' — kept
    as-is for backward compatibility with existing callers.
    """

    def __init__(self, **kwargs):
        get = kwargs.get
        # Interval (in steps) between log lines.
        self.log_every: int = get('log_every', 100)
        self.verbose: bool = get('verbose', False)
        self.use_wandb: bool = get('use_wandb', False)
class SampleConfig:
    """Settings that control periodic sample-image generation during training."""

    def __init__(self, **kwargs):
        self.sample_every: int = kwargs.get('sample_every', 100)
        self.width: int = kwargs.get('width', 512)
        self.height: int = kwargs.get('height', 512)
        # Fix: use typing.List (already imported at module level) instead of the
        # 3.9+-only builtin generic `list[str]`, for consistency with the rest
        # of the file and compatibility with older interpreters.
        self.prompts: List[str] = kwargs.get('prompts', [])
        # Negative prompt text, or False when none is supplied.
        self.neg = kwargs.get('neg', False)
        self.seed = kwargs.get('seed', 0)
        # When true, the seed is varied per sample instead of held fixed.
        self.walk_seed = kwargs.get('walk_seed', False)
        self.guidance_scale = kwargs.get('guidance_scale', 7)
        self.sample_steps = kwargs.get('sample_steps', 20)
        # Multiplier applied to the trained network when sampling.
        self.network_multiplier = kwargs.get('network_multiplier', 1)
class NetworkConfig:
    """Configuration for the auxiliary network (e.g. LoRA) parsed from kwargs.

    'rank' and 'linear' are aliases for the same dimension; 'rank' takes
    precedence when both are supplied (kept for backward compatibility).
    """

    def __init__(self, **kwargs):
        self.type: str = kwargs.get('type', 'lora')
        rank = kwargs.get('rank', None)
        linear = kwargs.get('linear', None)
        # Fix: the original left self.rank / self.linear undefined when
        # neither kwarg was given, causing AttributeError on any later
        # access. Default both to None instead.
        dim = rank if rank is not None else linear
        self.rank: int = dim  # 'rank' kept for backward compatibility
        self.linear: int = dim
        self.conv: int = kwargs.get('conv', None)
        self.alpha: float = kwargs.get('alpha', 1.0)
class TrainConfig:
    """Hyperparameters for the training loop, parsed from a kwargs mapping."""

    def __init__(self, **kwargs):
        # Table of (key, default) pairs; attribute names match kwarg keys.
        scalar_defaults = (
            ('noise_scheduler', 'ddpm'),
            ('steps', 1000),
            ('lr', 1e-6),
            ('optimizer', 'adamw'),
            ('lr_scheduler', 'constant'),
            ('max_denoising_steps', 50),
            ('batch_size', 1),
            ('dtype', 'fp32'),
            ('xformers', False),
            ('train_unet', True),
            ('train_text_encoder', True),
            ('noise_offset', 0.0),
        )
        for key, default in scalar_defaults:
            setattr(self, key, kwargs.get(key, default))
        # Mutable default handled outside the table so each instance gets a
        # fresh dict and never shares state with another config.
        self.optimizer_params = kwargs.get('optimizer_params', {})
        self.skip_first_sample = kwargs.get('skip_first_sample', False)
        self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False)
class ModelConfig:
    """Identifies the base model and its variant flags.

    Raises:
        ValueError: if 'name_or_path' is not supplied.
    """

    def __init__(self, **kwargs):
        path = kwargs.get('name_or_path', None)
        self.name_or_path: str = path
        self.is_v2: bool = kwargs.get('is_v2', False)
        self.is_xl: bool = kwargs.get('is_xl', False)
        self.is_v_pred: bool = kwargs.get('is_v_pred', False)
        # A model reference is mandatory; fail loudly at construction time.
        if path is None:
            raise ValueError('name_or_path must be specified')
class SliderTargetConfig:
    """One target concept for slider training: prompt pair plus weighting."""

    def __init__(self, **kwargs):
        get = kwargs.get
        self.target_class: str = get('target_class', '')
        self.positive: str = get('positive', '')
        self.negative: str = get('negative', '')
        self.multiplier: float = get('multiplier', 1.0)
        self.weight: float = get('weight', 1.0)
class SliderConfigAnchors:
    """An anchor prompt used to keep unrelated concepts fixed during training."""

    def __init__(self, **kwargs):
        get = kwargs.get
        self.prompt = get('prompt', '')
        self.neg_prompt = get('neg_prompt', '')
        self.multiplier = get('multiplier', 1.0)
class SliderConfig:
    """Top-level slider-training config: targets, anchors, resolutions, prompts."""

    def __init__(self, **kwargs):
        # Wrap each raw target/anchor dict in its typed config object.
        self.targets: List[SliderTargetConfig] = [
            SliderTargetConfig(**raw) for raw in kwargs.get('targets', [])
        ]
        self.anchors: List[SliderConfigAnchors] = [
            SliderConfigAnchors(**raw) for raw in kwargs.get('anchors', [])
        ]
        # Each entry is a [width, height] pair.
        self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
        self.prompt_file: str = kwargs.get('prompt_file', None)
        self.prompt_tensors: str = kwargs.get('prompt_tensors', None)