Mirror of https://github.com/ostris/ai-toolkit.git
Fix issue where ARA was not working when using memory manager
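In short: an ARA adapter stamps a weak reference to itself onto the layer it wraps (org_module.ara_lora_ref), and the memory managers now check for that marker so they hijack the adapter's org_forward instead of the layer's forward. Because an ARA adapter routes the base computation through org_forward, patching module.forward alone let ARA bypass the CPU-offload path. The minimal sketch below illustrates the dispatch; TinyAraLoRA and hijack_forward are hypothetical stand-ins, and only the attribute names (is_ara, ara_lora_ref, org_forward) come from this change.

import weakref

import torch


class TinyAraLoRA(torch.nn.Module):
    # Simplified stand-in for LoRAModule(is_ara=True); only the attribute names
    # (ara_lora_ref, org_forward) mirror the real change.
    def __init__(self, org_module: torch.nn.Linear):
        super().__init__()
        self.org_forward = org_module.forward  # adapter keeps its own handle to the base forward
        org_module.ara_lora_ref = weakref.ref(self)  # mark the layer so a manager can find the adapter
        self.down = torch.nn.Linear(org_module.in_features, 4, bias=False)
        self.up = torch.nn.Linear(4, org_module.out_features, bias=False)

    def forward(self, x):
        # the adapter routes the base computation through org_forward, not module.forward,
        # so org_forward is what a memory manager must hijack
        return self.org_forward(x) + self.up(self.down(x))


def hijack_forward(module: torch.nn.Module, managed_forward):
    # same dispatch as this commit: patch the marked adapter's org_forward,
    # otherwise patch the layer's forward directly
    if hasattr(module, "ara_lora_ref"):
        module.ara_lora_ref().org_forward = managed_forward
    else:
        module.forward = managed_forward


linear = torch.nn.Linear(8, 8)
adapter = TinyAraLoRA(linear)
hijack_forward(linear, lambda x: torch.zeros(x.shape[0], 8))  # toy "offloaded" forward
print(adapter(torch.randn(2, 8)).shape)  # torch.Size([2, 8]), now flowing through the managed path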
@@ -59,6 +59,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         module_dropout=None,
         network: 'LoRASpecialNetwork' = None,
         use_bias: bool = False,
+        is_ara: bool = False,
         **kwargs
     ):
         self.can_merge_in = True
@@ -68,6 +69,10 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         self.lora_name = lora_name
         self.orig_module_ref = weakref.ref(org_module)
         self.scalar = torch.tensor(1.0, device=org_module.weight.device)
+
+        # if is ara lora module, mark it on the layer so memory manager can handle it
+        if is_ara:
+            org_module.ara_lora_ref = weakref.ref(self)
         # check if parent has bias. if not force use_bias to False
         if org_module.bias is None:
             use_bias = False
@@ -193,6 +198,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         is_assistant_adapter: bool = False,
         is_transformer: bool = False,
         base_model: 'StableDiffusion' = None,
+        is_ara: bool = False,
         **kwargs
     ) -> None:
         """
@@ -247,6 +253,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.network_type = network_type
         self.is_assistant_adapter = is_assistant_adapter
         self.full_rank = network_type.lower() == "fullrank"
+        self.is_ara = is_ara
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
@@ -419,6 +426,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
 
                 if self.network_type.lower() == "lokr":
                     module_kwargs["factor"] = self.network_config.lokr_factor
 
+                if self.is_ara:
+                    module_kwargs["is_ara"] = True
+
                 lora = module_class(
                     lora_name,
@@ -560,7 +560,11 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
         _move_params_to_cpu_and_pin(self.module)
 
         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")
 
         def _mm_forward(x, *args, **kwargs):
             # ensure we only use expected signature (Linear: x)
@@ -575,7 +579,10 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
             # NOTE: do NOT move params to device here; autograd fn streams & bounces them
             return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device)
 
-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward
 
 
 class ConvLayerMemoryManager(BaseLayerMemoryManager):
@@ -608,7 +615,11 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
         groups = self.module.groups
 
         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")
 
         def _mm_forward(x, *args, **kwargs):
             # Support the typical Conv2d(x) call; if user passes uncommon extras, fallback.
@@ -623,4 +634,7 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
                 x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups
             )
 
-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward
@@ -247,6 +247,7 @@ def quantize_model(
         transformer_only=network_config.transformer_only,
         is_transformer=base_model.is_transformer,
        base_model=base_model,
+        is_ara=True,
         **network_kwargs
     )
     network.apply_to(