Built base interfaces for a DTO to handle batch infomation transports for the dataloader

This commit is contained in:
Jaret Burkett
2023-08-28 12:43:31 -06:00
parent 71da78c8af
commit e866c75638
7 changed files with 186 additions and 76 deletions

View File

@@ -3,6 +3,7 @@ import time
from typing import List, Optional, Literal
import random
ImgExt = Literal['jpg', 'png', 'webp']
class SaveConfig:
def __init__(self, **kwargs):
@@ -31,6 +32,7 @@ class SampleConfig:
self.sample_steps = kwargs.get('sample_steps', 20)
self.network_multiplier = kwargs.get('network_multiplier', 1)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext: ImgExt = kwargs.get('format', 'jpg')
class NetworkConfig:
@@ -158,7 +160,10 @@ class SliderConfig:
class DatasetConfig:
caption_type: Literal["txt", "caption"] = 'txt'
"""
Dataset config for sd-datasets
"""
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'image') # sd, slider, reference
@@ -172,6 +177,10 @@ class DatasetConfig:
self.buckets: bool = kwargs.get('buckets', False)
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
self.is_reg: bool = kwargs.get('is_reg', False)
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
class GenerateImageConfig:
@@ -191,7 +200,7 @@ class GenerateImageConfig:
# the tag [time] will be replaced with milliseconds since epoch
output_path: str = None, # full image path
output_folder: str = None, # folder to save image in if output_path is not specified
output_ext: str = 'png', # extension to save image as if output_path is not specified
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
output_tail: str = '', # tail to add to output filename
add_prompt_file: bool = False, # add a prompt file with generated image
):