feat(ara): fix issues on some multi-GPU setups

Co-authored-by: kabachuha <artemkhrapov2001@yandex.ru>
This commit is contained in:
Philipp Emanuel Weidmann
2026-03-07 18:24:31 +05:30
parent 0bb9521fbe
commit d79a443e6f

View File

@@ -570,6 +570,11 @@ class Model:
module_index
]
good_input = good_input.to(matrix.device)
good_output = good_output.to(matrix.device)
bad_input = bad_input.to(matrix.device)
bad_output = bad_output.to(matrix.device)
def objective(matrix: Tensor) -> Tensor:
new_good_output = good_input @ matrix.T
new_bad_output = bad_input @ matrix.T