diff --git a/comfy/model_management.py b/comfy/model_management.py index 101c57c1c..38c3e482b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1214,12 +1214,11 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): - raw_tensor = weight._v_tensor - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + v_tensor = weight._v_tensor else: raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] - weight._v_tensor = raw_tensor + weight._v_tensor = v_tensor weight._v_signature = signature #Send it over v_tensor.copy_(weight, non_blocking=non_blocking) diff --git a/comfy/ops.py b/comfy/ops.py index 9a824fab2..688937e43 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -83,18 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): offload_stream = None xfer_dest = None - cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) if signature is not None: if resident: - xfer_dest = s._v_tensor + weight = s._v_weight + bias = s._v_bias else: xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) - s._v_tensor = xfer_dest if not resident: + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None xfer_source = [ s.weight, s.bias ] @@ -144,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) - weight = params[0] - bias = params[1] + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + weight = params[0] + bias = params[1] + if signature is not None: + s._v_weight = weight + s._v_bias = bias + s._v_signature=signature def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) @@ -186,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu weight = post_cast(s, "weight", weight, dtype, resident, update_weight) if s.bias is not None: bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) - s._v_signature=signature #FIXME: weird offload return protocol return weight, bias, (offload_stream, device if signature is not None else None, None)