diff --git a/comfy/model_management.py b/comfy/model_management.py index d59d6b354..101c57c1c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1213,9 +1213,13 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: - raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] - if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + raw_tensor = weight._v_tensor + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + else: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + weight._v_tensor = raw_tensor weight._v_signature = signature #Send it over v_tensor.copy_(weight, non_blocking=non_blocking) diff --git a/comfy/ops.py b/comfy/ops.py index 057535c8c..9a824fab2 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,9 +86,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - if signature is not None: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + if signature is not None: + if resident: + xfer_dest = s._v_tensor + else: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + s._v_tensor = xfer_dest if not resident: cast_dest = None