mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added ability to train control loras. Other important bug fixes thrown in
This commit is contained in:
@@ -106,7 +106,7 @@ class LoRMConfig:
|
||||
})
|
||||
|
||||
|
||||
NetworkType = Literal['lora', 'locon', 'lorm']
|
||||
NetworkType = Literal['lora', 'locon', 'lorm', 'lokr']
|
||||
|
||||
|
||||
class NetworkConfig:
|
||||
@@ -151,7 +151,7 @@ class NetworkConfig:
|
||||
self.lokr_factor = kwargs.get('lokr_factor', -1)
|
||||
|
||||
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora']
|
||||
|
||||
CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
|
||||
|
||||
@@ -234,6 +234,13 @@ class AdapterConfig:
|
||||
# for llm adapter
|
||||
self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
|
||||
self.quantize_llm: bool = kwargs.get('quantize_llm', False)
|
||||
|
||||
# for control lora only
|
||||
lora_config: dict = kwargs.get('lora_config', None)
|
||||
if lora_config is not None:
|
||||
self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
|
||||
else:
|
||||
self.lora_config = None
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
@@ -521,6 +528,32 @@ class ModelConfig:
|
||||
self.arch: ModelArch = kwargs.get("arch", None)
|
||||
|
||||
# handle migrating to new model arch
|
||||
if self.arch is not None:
|
||||
# reverse the arch to the old style
|
||||
if self.arch == 'sd2':
|
||||
self.is_v2 = True
|
||||
elif self.arch == 'sd3':
|
||||
self.is_v3 = True
|
||||
elif self.arch == 'sdxl':
|
||||
self.is_xl = True
|
||||
elif self.arch == 'pixart':
|
||||
self.is_pixart = True
|
||||
elif self.arch == 'pixart_sigma':
|
||||
self.is_pixart_sigma = True
|
||||
elif self.arch == 'auraflow':
|
||||
self.is_auraflow = True
|
||||
elif self.arch == 'flux':
|
||||
self.is_flux = True
|
||||
elif self.arch == 'flex2':
|
||||
self.is_flex2 = True
|
||||
elif self.arch == 'lumina2':
|
||||
self.is_lumina2 = True
|
||||
elif self.arch == 'vega':
|
||||
self.is_vega = True
|
||||
elif self.arch == 'ssd':
|
||||
self.is_ssd = True
|
||||
else:
|
||||
pass
|
||||
if self.arch is None:
|
||||
if kwargs.get('is_v2', False):
|
||||
self.arch = 'sd2'
|
||||
|
||||
Reference in New Issue
Block a user