diff --git a/src/heretic/model.py b/src/heretic/model.py index 0a4b4fb..898d4df 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -570,6 +570,11 @@ class Model: module_index ] + good_input = good_input.to(matrix.device) + good_output = good_output.to(matrix.device) + bad_input = bad_input.to(matrix.device) + bad_output = bad_output.to(matrix.device) + def objective(matrix: Tensor) -> Tensor: new_good_output = good_input @ matrix.T new_bad_output = bad_input @ matrix.T