dynamic_vram: Minimize fast path CPU work

Move as much work as possible inside the "if not resident" block, and
cache the formed weight and bias rather than the flat intermediates. At
extreme layer weight rates, this adds up.
This commit is contained in:
Rattus
2026-02-12 00:24:43 +10:00
parent f7aebddcf6
commit 8423394577
2 changed files with 12 additions and 10 deletions

View File

@@ -1214,12 +1214,11 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
raw_tensor = weight._v_tensor
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = raw_tensor
weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)

View File

@@ -83,18 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
xfer_dest = s._v_tensor
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
s._v_tensor = xfer_dest
if not resident:
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@@ -144,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -186,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)