Get Flux working on Apple Silicon (#1264)

Co-authored-by: Conor Nash <conor@nbs.consulting>
This commit is contained in:
Conor Nash
2024-09-13 15:40:11 +01:00
committed by GitHub
parent cb412b290b
commit 8bd7e0568f

View File

@@ -19,7 +19,10 @@ def attention(q, k, v, pe):
def rope(pos, dim, theta):
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    if pos.device.type == "mps":
+        scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
+    else:
+        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta ** scale)
# out = torch.einsum("...n,d->...nd", pos, omega)