From 8bd7e0568f4d942f4d17bcc7a6966e92d2b897eb Mon Sep 17 00:00:00 2001
From: Conor Nash
Date: Fri, 13 Sep 2024 15:40:11 +0100
Subject: [PATCH] Get Flux working on Apple Silicon (#1264)

Co-authored-by: Conor Nash
---
 backend/nn/flux.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/backend/nn/flux.py b/backend/nn/flux.py
index 8c5b0f48..99546a55 100644
--- a/backend/nn/flux.py
+++ b/backend/nn/flux.py
@@ -19,7 +19,10 @@ def attention(q, k, v, pe):
 
 
 def rope(pos, dim, theta):
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    if pos.device.type == "mps":
+        scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
+    else:
+        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
     omega = 1.0 / (theta ** scale)
     # out = torch.einsum("...n,d->...nd", pos, omega)