mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added base setup for training T2I adapters. Currently untested; saw something else shiny I wanted to finish first. Added content_or_style to the training config. It defaults to balanced, which is standard uniform timestep sampling. If style or content is passed, it will use cubic sampling of timesteps to favor those that are beneficial for training each: for style, favor later timesteps; for content, favor earlier timesteps.
This commit is contained in:
@@ -39,6 +39,7 @@ class SampleConfig:
|
||||
|
||||
# Supported network/adapter types selectable via NetworkConfig.type.
NetworkType = Literal['lora', 'locon']
|
||||
|
||||
|
||||
class NetworkConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.type: NetworkType = kwargs.get('type', 'lora')
|
||||
@@ -58,6 +59,17 @@ class NetworkConfig:
|
||||
self.dropout: Union[float, None] = kwargs.get('dropout', None)
|
||||
|
||||
|
||||
class AdapterConfig:
    """Configuration for a T2I adapter (image-conditioning network).

    All values are read from keyword arguments, falling back to defaults
    when a key is absent.
    """

    def __init__(self, **kwargs):
        # Channel count of the conditioning image input (3 = RGB).
        self.in_channels: int = kwargs.get('in_channels', 3)
        # Channel widths per adapter stage.
        self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
        # Residual blocks per stage.
        self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
        # Total spatial downscale applied to the conditioning image.
        self.downscale_factor: int = kwargs.get('downscale_factor', 16)
        # Adapter variant identifier, e.g. 'full_adapter'.
        self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
        # Optional directory of conditioning images; None when unset.
        # Annotated Union[str, None] (was plain `str`) to match the
        # None default, consistent with NetworkConfig.dropout's style.
        self.image_dir: Union[str, None] = kwargs.get('image_dir', None)
        # Optional path to a test conditioning image; None when unset.
        self.test_img_path: Union[str, None] = kwargs.get('test_img_path', None)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.trigger = kwargs.get('trigger', 'custom_embedding')
|
||||
@@ -66,9 +78,13 @@ class EmbeddingConfig:
|
||||
self.save_format = kwargs.get('save_format', 'safetensors')
|
||||
|
||||
|
||||
# Timestep-sampling bias for training: 'balanced' = uniform sampling;
# per the commit notes, 'style' and 'content' use cubic sampling to favor
# later / earlier timesteps respectively — verify against the trainer code.
ContentOrStyleType = Literal['balanced', 'style', 'content']
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
|
||||
self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
|
||||
self.steps: int = kwargs.get('steps', 1000)
|
||||
self.lr = kwargs.get('lr', 1e-6)
|
||||
self.unet_lr = kwargs.get('unet_lr', self.lr)
|
||||
@@ -80,8 +96,6 @@ class TrainConfig:
|
||||
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
|
||||
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
|
||||
self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False)
|
||||
self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False)
|
||||
self.batch_size: int = kwargs.get('batch_size', 1)
|
||||
self.dtype: str = kwargs.get('dtype', 'fp32')
|
||||
self.xformers = kwargs.get('xformers', False)
|
||||
@@ -255,6 +269,7 @@ class GenerateImageConfig:
|
||||
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
|
||||
output_tail: str = '', # tail to add to output filename
|
||||
add_prompt_file: bool = False, # add a prompt file with generated image
|
||||
adapter_image_path: str = None, # path to adapter image
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -277,6 +292,7 @@ class GenerateImageConfig:
|
||||
self.add_prompt_file: bool = add_prompt_file
|
||||
self.output_tail: str = output_tail
|
||||
self.gen_time: int = int(time.time() * 1000)
|
||||
self.adapter_image_path: str = adapter_image_path
|
||||
|
||||
# prompt string will override any settings above
|
||||
self._process_prompt_string()
|
||||
|
||||
Reference in New Issue
Block a user