mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Improved LoRM extraction and training
This commit is contained in:
@@ -40,7 +40,48 @@ class SampleConfig:
|
||||
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
||||
|
||||
|
||||
NetworkType = Literal['lora', 'locon']
|
||||
class LormModuleSettingsConfig:
    """Per-module override settings for LoRM extraction.

    Instances are matched against module (block) names via the ``contains``
    pattern; see ``LoRMConfig.get_config_for_module`` for the lookup logic.

    Keyword Args:
        contains (str): '|'-separated substrings that must ALL appear in a
            module name for this setting to apply. The default ``'4nt$3'``
            is an intentionally unmatchable sentinel, so an unconfigured
            instance never matches anything.
        extract_mode (str): strategy used to pick the extracted rank
            (default ``'ratio'``).
        parameter_threshold (int): minimum number of parameters a module
            must have to be attached to (default 0).
        extract_mode_param (float | dict): parameter for the chosen extract
            mode, e.g. the ratio ``0.25``.
    """

    def __init__(self, **kwargs):
        self.contains: str = kwargs.get('contains', '4nt$3')
        self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
        # min num parameters to attach to
        self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
        # NOTE(review): default is a float (0.25), so the original `dict`
        # annotation was wrong — corrected to Union[float, dict].
        self.extract_mode_param: Union[float, dict] = kwargs.get('extract_mode_param', 0.25)
|
||||
|
||||
|
||||
class LoRMConfig:
    """Top-level LoRM extraction settings.

    Holds global extraction defaults plus an ordered list of per-module
    overrides (``module_settings``). Lookups via
    :meth:`get_config_for_module` return the first matching override, or a
    settings object built from the global defaults.

    Keyword Args:
        extract_mode (str): default extraction strategy (``'ratio'``).
        extract_mode_param (float | dict): default parameter for the
            strategy (e.g. the ratio ``0.25``).
        parameter_threshold (int): default minimum parameter count for a
            module to be attached to.
        module_settings (list[dict]): per-module override dicts; any key
            omitted in an entry inherits the global default above.
    """

    def __init__(self, **kwargs):
        self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
        # NOTE(review): default is a float (0.25), so the original `dict`
        # annotation was wrong — corrected to Union[float, dict].
        self.extract_mode_param: Union[float, dict] = kwargs.get('extract_mode_param', 0.25)
        # min num parameters to attach to
        self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)

        module_settings = kwargs.get('module_settings', [])
        # Each per-module entry inherits the global defaults unless it
        # explicitly overrides them (dict-merge, entry wins).
        default_module_settings = {
            'extract_mode': self.extract_mode,
            'extract_mode_param': self.extract_mode_param,
            'parameter_threshold': self.parameter_threshold,
        }
        module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings]
        self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for
                                                               module_setting in module_settings]

    def get_config_for_module(self, block_name):
        """Return the settings that apply to ``block_name``.

        Each setting's ``contains`` is a '|'-separated list of substrings
        that must ALL be present in ``block_name`` (AND semantics). A second
        pass retries with '.' replaced by '_' so both dotted and underscored
        module-name conventions match. First match wins; falls back to a
        fresh settings object built from the global defaults.
        """
        for setting in self.module_settings:
            contain_pieces = setting.contains.split('|')
            if all(contain_piece in block_name for contain_piece in contain_pieces):
                return setting
            # try replacing the . with _
            contain_pieces = setting.contains.replace('.', '_').split('|')
            if all(contain_piece in block_name for contain_piece in contain_pieces):
                return setting
        # do default
        return LormModuleSettingsConfig(**{
            'extract_mode': self.extract_mode,
            'extract_mode_param': self.extract_mode_param,
            'parameter_threshold': self.parameter_threshold,
        })
|
||||
|
||||
|
||||
# Supported adapter-network architectures; 'lorm' is the low-rank-module
# extraction type configured via LoRMConfig.
NetworkType = Literal['lora', 'locon', 'lorm']
|
||||
|
||||
|
||||
class NetworkConfig:
|
||||
@@ -58,12 +99,22 @@ class NetworkConfig:
|
||||
self.alpha: float = kwargs.get('alpha', 1.0)
|
||||
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
|
||||
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
|
||||
self.normalize = kwargs.get('normalize', False)
|
||||
self.dropout: Union[float, None] = kwargs.get('dropout', None)
|
||||
|
||||
self.lorm_config: Union[LoRMConfig, None] = None
|
||||
lorm = kwargs.get('lorm', None)
|
||||
if lorm is not None:
|
||||
self.lorm_config: LoRMConfig = LoRMConfig(**lorm)
|
||||
|
||||
if self.type == 'lorm':
|
||||
# set linear to arbitrary values so it makes them
|
||||
self.linear = 4
|
||||
self.rank = 4
|
||||
|
||||
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+']
|
||||
|
||||
|
||||
class AdapterConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
|
||||
@@ -90,6 +141,7 @@ class EmbeddingConfig:
|
||||
ContentOrStyleType = Literal['balanced', 'style', 'content']
|
||||
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
|
||||
@@ -138,7 +190,8 @@ class TrainConfig:
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target',
|
||||
'noise') # noise, source, unaugmented, differential_noise
|
||||
|
||||
# When a mask is passed in a dataset, and this is true,
|
||||
# we will predict noise without a the LoRa network and use the prediction as a target for
|
||||
@@ -151,7 +204,6 @@ class TrainConfig:
|
||||
self.match_adapter_chance = 1.0
|
||||
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.name_or_path: str = kwargs.get('name_or_path', None)
|
||||
@@ -216,7 +268,7 @@ class SliderConfig:
|
||||
self.prompt_file: str = kwargs.get('prompt_file', None)
|
||||
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
|
||||
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
|
||||
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
|
||||
self.high_ram = kwargs.get('high_ram', False)
|
||||
|
||||
@@ -267,9 +319,11 @@ class DatasetConfig:
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
|
||||
self.mask_path: str = kwargs.get('mask_path',
|
||||
None) # focus mask (black and white. White has higher loss than black)
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
self.poi: Union[str, None] = kwargs.get('poi',
|
||||
None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
|
||||
# cache latents will store them in memory
|
||||
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
||||
@@ -525,4 +579,4 @@ class GenerateImageConfig:
|
||||
unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
|
||||
):
|
||||
# this is called after prompt embeds are encoded. We can override them in the future here
|
||||
pass
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user