Improved lorm extraction and training

This commit is contained in:
Jaret Burkett
2023-10-28 08:21:59 -06:00
parent 0a79ac9604
commit 6f3e0d5af2
10 changed files with 559 additions and 196 deletions

View File

@@ -40,7 +40,48 @@ class SampleConfig:
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
NetworkType = Literal['lora', 'locon']
class LormModuleSettingsConfig:
def __init__(self, **kwargs):
self.contains: str = kwargs.get('contains', '4nt$3')
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
# min num parameters to attach to
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
class LoRMConfig:
def __init__(self, **kwargs):
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
module_settings = kwargs.get('module_settings', [])
default_module_settings = {
'extract_mode': self.extract_mode,
'extract_mode_param': self.extract_mode_param,
'parameter_threshold': self.parameter_threshold,
}
module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings]
self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for
module_setting in module_settings]
def get_config_for_module(self, block_name):
for setting in self.module_settings:
contain_pieces = setting.contains.split('|')
if all(contain_piece in block_name for contain_piece in contain_pieces):
return setting
# try replacing the . with _
contain_pieces = setting.contains.replace('.', '_').split('|')
if all(contain_piece in block_name for contain_piece in contain_pieces):
return setting
# do default
return LormModuleSettingsConfig(**{
'extract_mode': self.extract_mode,
'extract_mode_param': self.extract_mode_param,
'parameter_threshold': self.parameter_threshold,
})
NetworkType = Literal['lora', 'locon', 'lorm']
class NetworkConfig:
@@ -58,12 +99,22 @@ class NetworkConfig:
self.alpha: float = kwargs.get('alpha', 1.0)
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
self.normalize = kwargs.get('normalize', False)
self.dropout: Union[float, None] = kwargs.get('dropout', None)
self.lorm_config: Union[LoRMConfig, None] = None
lorm = kwargs.get('lorm', None)
if lorm is not None:
self.lorm_config: LoRMConfig = LoRMConfig(**lorm)
if self.type == 'lorm':
# set linear to arbitrary values so it makes them
self.linear = 4
self.rank = 4
AdapterTypes = Literal['t2i', 'ip', 'ip+']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
@@ -90,6 +141,7 @@ class EmbeddingConfig:
ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
@@ -138,7 +190,8 @@ class TrainConfig:
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise
self.loss_target: LossTarget = kwargs.get('loss_target',
'noise') # noise, source, unaugmented, differential_noise
# When a mask is passed in a dataset, and this is true,
# we will predict noise without a the LoRa network and use the prediction as a target for
@@ -151,7 +204,6 @@ class TrainConfig:
self.match_adapter_chance = 1.0
class ModelConfig:
def __init__(self, **kwargs):
self.name_or_path: str = kwargs.get('name_or_path', None)
@@ -216,7 +268,7 @@ class SliderConfig:
self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
self.high_ram = kwargs.get('high_ram', False)
@@ -267,9 +319,11 @@ class DatasetConfig:
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
self.mask_path: str = kwargs.get('mask_path',
None) # focus mask (black and white. White has higher loss than black)
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
self.poi: Union[str, None] = kwargs.get('poi',
None) # if one is set and in json data, will be used as auto crop scale point of interes
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
# cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False)
@@ -525,4 +579,4 @@ class GenerateImageConfig:
unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
):
# this is called after prompt embeds are encoded. We can override them in the future here
pass
pass