ops: implement lora requanting for non QuantizedTensor fp8 (#12668)

Allow non-QuantizedTensor layers to set want_requant so that the result of
the post-LoRA calculation is stochastically cast back down to the original
input dtype.

The legacy fp8 Linear implementation then uses this to set compute_dtype
to the preferred LoRA dtype, while passing want_requant so that the result
is requantized back down to fp8.

This fixes the issue that occurs when --fast fp8_matrix_mult is combined
with --fast dynamic_vram while applying a LoRA to an fp8_* non-QuantizedTensor
(non-QT) model.
This commit is contained in:
rattus
2026-02-27 16:05:51 -08:00
committed by GitHub
parent 25ec3d96a3
commit e721e24136

View File

@@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(want_requant and len(fns) == 0 or update_weight)):
if (want_requant and len(fns) == 0 or update_weight):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
if isinstance(orig, QuantizedTensor):
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
else:
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
if want_requant and len(fns) == 0:
x = y
if update_weight:
orig.copy_(y)
for f in fns:
@@ -617,7 +615,8 @@ def fp8_linear(self, input):
if input.ndim != 2:
return None
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)