mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Tied in ant tested TI script
This commit is contained in:
@@ -54,9 +54,10 @@ class EmbeddingConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.trigger = kwargs.get('trigger', 'custom_embedding')
|
||||
self.tokens = kwargs.get('tokens', 4)
|
||||
self.init_words = kwargs.get('init_phrase', '*')
|
||||
self.init_words = kwargs.get('init_words', '*')
|
||||
self.save_format = kwargs.get('save_format', 'safetensors')
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
|
||||
@@ -75,6 +76,7 @@ class TrainConfig:
|
||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
||||
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
@@ -165,6 +167,7 @@ class DatasetConfig:
|
||||
self.random_crop: bool = kwargs.get('random_crop', False)
|
||||
self.resolution: int = kwargs.get('resolution', 512)
|
||||
self.scale: float = kwargs.get('scale', 1.0)
|
||||
self.buckets: bool = kwargs.get('buckets', False)
|
||||
|
||||
|
||||
class GenerateImageConfig:
|
||||
|
||||
Reference in New Issue
Block a user