diff --git a/backend/nn/flux.py b/backend/nn/flux.py
index 8c5b0f48..99546a55 100644
--- a/backend/nn/flux.py
+++ b/backend/nn/flux.py
@@ -19,7 +19,10 @@ def attention(q, k, v, pe):
 
 
 def rope(pos, dim, theta):
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    if pos.device.type == "mps":
+        scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
+    else:
+        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
     omega = 1.0 / (theta ** scale)
     # out = torch.einsum("...n,d->...nd", pos, omega)
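
For context (not part of the patch itself): PyTorch's MPS backend has no float64 support, so building the rotary-embedding frequency scale in float64 on an Apple-silicon device fails at tensor creation time. The hunk above guards on `pos.device.type == "mps"` and falls back to float32 there, keeping float64 everywhere else. A minimal sketch of the same dtype-selection pattern in isolation, using only standard PyTorch APIs; the helper name `rope_scale` is hypothetical and not part of the patched file:

```python
import torch

def rope_scale(dim: int, device: torch.device) -> torch.Tensor:
    # MPS has no float64 kernels, so fall back to float32 there; keep
    # float64 elsewhere for extra precision in the rotary frequencies.
    dtype = torch.float32 if device.type == "mps" else torch.float64
    return torch.arange(0, dim, 2, dtype=dtype, device=device) / dim

# Hypothetical usage, mirroring how rope() derives omega from the scale.
scale = rope_scale(64, torch.device("cpu"))
omega = 1.0 / (10000.0 ** scale)
```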