fix loras on nf4 models when activate "loras in fp16"

This commit is contained in:
layerdiffusion
2024-08-20 01:29:52 -07:00
parent 65ec461f8a
commit 6f411a4940
2 changed files with 9 additions and 1 deletions

View File

@@ -342,7 +342,7 @@ class ForgeOperations:
try:
from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits
from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits, functional_dequantize_4bit
class ForgeOperationsBNB4bits(ForgeOperations):
class Linear(ForgeLoader4Bit):
@@ -356,6 +356,11 @@ try:
# And it only invokes one time, and most linear does not have bias
self.bias = utils.tensor2parameter(self.bias.to(x.dtype))
if hasattr(self, 'forge_online_loras'):
weight, bias, signal = weights_manual_cast(self, x, weight_fn=functional_dequantize_4bit, bias_fn=None, skip_bias_dtype=True)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.linear(x, weight, bias)
if not self.parameters_manual_cast:
return functional_linear_4bits(x, self.weight, self.bias)
elif not self.weight.bnb_quantized:

View File

@@ -25,9 +25,11 @@ forge_unet_storage_dtype_options = {
'Automatic': (None, False),
'Automatic (fp16 LoRA)': (None, True),
'bnb-nf4': ('nf4', False),
'bnb-nf4 (fp16 LoRA)': ('nf4', True),
'float8-e4m3fn': (torch.float8_e4m3fn, False),
'float8-e4m3fn (fp16 LoRA)': (torch.float8_e4m3fn, True),
'bnb-fp4': ('fp4', False),
'bnb-fp4 (fp16 LoRA)': ('fp4', True),
'float8-e5m2': (torch.float8_e5m2, False),
'float8-e5m2 (fp16 LoRA)': (torch.float8_e5m2, True),
}
@@ -195,6 +197,7 @@ def refresh_model_loading_parameters():
)
print(f'Model selected: {model_data.forge_loading_parameters}')
processing.need_global_unload = True
return