Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-28 10:54:05 +00:00)
ops: implement lora requanting for non QuantizedTensor fp8 (#12668)
Allow a non-QuantizedTensor layer to set want_requant so the post-LoRA calculation is stochastically cast back down to the original input dtype. This is then used by the legacy fp8 Linear implementation to set compute_dtype to the preferred LoRA dtype and then want_requant the result back down to fp8. This fixes the issue where --fast fp8_matrix_mult combined with --fast dynamic_vram broke applying a LoRA to an fp8 non-QuantizedTensor model.
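At a high level, the new path dequantizes a plain fp8 weight to the LoRA compute dtype, applies the LoRA patches there, and stochastically rounds the result back down to fp8. A minimal, self-contained sketch of that round trip, assuming torch.float8_e4m3fn storage; the dithering helper below only approximates what ComfyUI does with comfy.float.stochastic_rounding, and the delta is a stand-in for a real LoRA patch:

import torch

def stochastic_cast_fp8(x: torch.Tensor, seed: int) -> torch.Tensor:
    # Approximate stochastic rounding by dithering: estimate the local
    # quantization step from the round-to-nearest error, jitter the input
    # by up to half a step, then cast. Illustrative only; ComfyUI uses
    # comfy.float.stochastic_rounding / QuantizedTensor.from_float here.
    g = torch.Generator(device=x.device).manual_seed(seed)
    rtn = x.to(torch.float8_e4m3fn).to(x.dtype)    # round-to-nearest reference
    step = (x - rtn).abs() * 2                     # rough local step size
    noise = (torch.rand(x.shape, generator=g, device=x.device) - 0.5) * step
    return (x + noise).to(torch.float8_e4m3fn)

weight_fp8 = torch.randn(64, 64).to(torch.float8_e4m3fn)  # plain fp8 weight, not a QuantizedTensor
w = weight_fp8.to(torch.bfloat16)                 # dequant to the LoRA compute dtype
w = w + 0.01 * torch.randn_like(w)                # stand-in for the LoRA delta
weight_fp8 = stochastic_cast_fp8(w, seed=1234)    # want_requant: back down to fp8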
This commit is contained in:

comfy/ops.py (19 changed lines)
@@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
-        x = to_dequant(x, dtype)
     if not resident and lowvram_fn is not None:
-        #FIXME: this is not accurate, we need to be sensitive to the compute dtype
+        x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
         x = lowvram_fn(x)
-    if (isinstance(orig, QuantizedTensor) and
-            (want_requant and len(fns) == 0 or update_weight)):
+    if (want_requant and len(fns) == 0 or update_weight):
         seed = comfy.utils.string_to_seed(s.seed_key)
-        y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
-        if want_requant and len(fns) == 0:
-            #The layer actually wants our freshly saved QT
-            x = y
-    elif update_weight:
-        y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
+        if isinstance(orig, QuantizedTensor):
+            y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
+        else:
+            y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
+        if want_requant and len(fns) == 0:
+            x = y
     if update_weight:
         orig.copy_(y)
     for f in fns:
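The requant branch now covers both weight kinds under one condition: QuantizedTensor weights go through QuantizedTensor.from_float, plain tensors through comfy.float.stochastic_rounding, and both derive their seed from the layer's seed_key so repeated requants of the same layer are bit-reproducible. A sketch of that style of string-derived seeding, assuming a SHA-256 hash; comfy.utils.string_to_seed's exact implementation may differ:

import hashlib

def string_to_seed(key: str) -> int:
    # Stable across processes, unlike Python's salted built-in hash().
    return int.from_bytes(hashlib.sha256(key.encode()).digest()[:8], "little")

# Same layer key -> same seed -> identical stochastic rounding on every call.
seed = string_to_seed("model.layers.0.attn.qkv.weight")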
@@ -617,7 +615,8 @@ def fp8_linear(self, input):
 
     if input.ndim != 2:
         return None
-    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+    lora_compute_dtype = comfy.model_management.lora_compute_dtype(input.device)
+    w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
     scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
 
     scale_input = torch.ones((), device=input.device, dtype=torch.float32)
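On the caller side, fp8_linear previously cast the weight straight to the fp8 matmul dtype, which left any LoRA patches to be computed in fp8. It now asks cast_bias_weight to run the patches in the preferred LoRA dtype and requant afterwards. A rough sketch of that contract; cast_weight_for_fp8_matmul is a hypothetical stand-in for comfy.ops.cast_bias_weight, and a plain .to() stands in for the seeded stochastic rounding:

import torch

def cast_weight_for_fp8_matmul(weight, lora_fns, compute_dtype, want_requant):
    if not lora_fns:
        return weight                      # no patches: use the stored fp8 weight as-is
    x = weight.to(compute_dtype)           # dequant to the preferred LoRA dtype
    for f in lora_fns:
        x = f(x)                           # apply LoRA patches at higher precision
    if want_requant:
        x = x.to(weight.dtype)             # back to fp8 for the fp8 matmul path
    return x

w8 = torch.randn(8, 8).to(torch.float8_e4m3fn)
lora = [lambda t: t + 0.01 * torch.randn_like(t)]
w = cast_weight_for_fp8_matmul(w8, lora, torch.bfloat16, want_requant=True)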