fix loras on nf4 models when activate "loras in fp16"

This commit is contained in:
layerdiffusion
2024-08-20 01:29:52 -07:00
parent 65ec461f8a
commit 6f411a4940
2 changed files with 9 additions and 1 deletions

View File

@@ -342,7 +342,7 @@ class ForgeOperations:
try:
from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits
from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits, functional_dequantize_4bit
class ForgeOperationsBNB4bits(ForgeOperations):
class Linear(ForgeLoader4Bit):
@@ -356,6 +356,11 @@ try:
# And it only invokes one time, and most linear does not have bias
self.bias = utils.tensor2parameter(self.bias.to(x.dtype))
if hasattr(self, 'forge_online_loras'):
weight, bias, signal = weights_manual_cast(self, x, weight_fn=functional_dequantize_4bit, bias_fn=None, skip_bias_dtype=True)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.linear(x, weight, bias)
if not self.parameters_manual_cast:
return functional_linear_4bits(x, self.weight, self.bias)
elif not self.weight.bnb_quantized: