Mirror of https://github.com/ostris/ai-toolkit.git
Add method to do continuous LoRA merging-in for low-VRAM full finetuning.
@@ -404,6 +404,15 @@ class ToolkitModuleMixin:
         # set weight to org_module
         org_sd[weight_key] = weight.to(weight_device, orig_dtype)
         self.org_module[0].load_state_dict(org_sd)
+
+    def reset_weights(self: Module):
+        # reset the weights to zero
+        org_sd = self.state_dict()
+        for key in org_sd.keys():
+            # only reset lora up
+            if 'lora_up' in key:
+                org_sd[key] = torch.zeros_like(org_sd[key])
+        self.load_state_dict(org_sd)
 
     def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
         # LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while keeping the inputs and
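Zeroing only the lora_up tensors is enough to neutralize an adapter: the LoRA delta applied to a layer is the product lora_up @ lora_down, so a zero up-projection makes the whole delta vanish regardless of what lora_down holds. A minimal standalone check of that identity (the shapes here are hypothetical, not taken from the toolkit):

    import torch

    rank, d_in, d_out = 4, 16, 16        # hypothetical sizes
    lora_down = torch.randn(rank, d_in)  # down-projection keeps its values
    lora_up = torch.zeros(d_out, rank)   # zeroed, as reset_weights does
    delta = lora_up @ lora_down          # the adapter's weight delta
    assert torch.all(delta == 0)         # no-op regardless of lora_down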
@@ -811,6 +820,10 @@ class ToolkitNetworkMixin:
         # not supported
         self.is_checkpointing = False
         self._update_checkpointing()
+
+    def reset_weights(self: Network):
+        for module in self.get_all_modules():
+            module.reset_weights()
 
     def merge_in(self, merge_weight=1.0):
         if self.network_type.lower() == 'dora':
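Taken together, the two methods support the continuous merge-in workflow named in the commit message: every so many steps, the accumulated LoRA delta is folded into the frozen base weights with merge_in, and the adapters are returned to a no-op state with reset_weights, so the full model drifts toward the training objective while only the low-rank parameters ever hold gradients or optimizer state. A self-contained sketch of that idea on a single linear layer (the layer, loop, and hyperparameters are illustrative assumptions, not the toolkit's API; only the merge-then-reset pattern mirrors the commit):

    import torch
    from torch import nn

    base = nn.Linear(16, 16, bias=False)
    base.weight.requires_grad_(False)              # base stays frozen
    rank = 4
    lora_down = nn.Parameter(torch.randn(rank, 16) * 0.01)
    lora_up = nn.Parameter(torch.zeros(16, rank))  # starts as a no-op adapter
    optimizer = torch.optim.SGD([lora_down, lora_up], lr=1e-2)

    for step in range(100):
        x = torch.randn(8, 16)
        # forward through the base weight plus the live LoRA delta
        loss = (x @ (base.weight + lora_up @ lora_down).T).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if (step + 1) % 10 == 0:
            with torch.no_grad():
                # merge_in: fold the LoRA delta into the base weight
                base.weight += lora_up @ lora_down
                # reset_weights: zero lora_up so the adapter is a no-op again
                lora_up.zero_()

Because only lora_up is reset, lora_down keeps its learned directions between merge intervals, matching what the module-level reset_weights above does.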