# exllamav2/tests/test_hgemm.py
# (snapshot dated 2023-11-25 11:54:53 +01:00; 37 lines, 1.2 KiB, Python)
import torch
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
import random
# Test shapes as (m, k, n) triples: a few fixed edge cases, then randomized
# sizes. Randomized k is always a multiple of 32 (kernel alignment, presumably).
shapes = [  # m, k, n
    [1, 16, 1],
    [16, 16, 16],
    [8, 256, 32],
    [8, 8, 256],
]
for _ in range(10):
    # Tall case: many rows, few columns.
    shapes.append([random.randint(1, 200), random.randint(1, 16) * 32, random.randint(1, 10)])
    # Wide case: few rows, many columns.
    shapes.append([random.randint(1, 10), random.randint(1, 4) * 32, random.randint(1, 200)])
# For each shape, run the extension's half-precision GEMM twice (last flag
# False vs True — presumably custom kernel vs forced cuBLAS path; confirm in
# ext source) and compare against torch.matmul as reference.
for m, k, n in shapes:
    # Shape header, padded so the diff columns line up across rows.
    print(f" ({m}, {k}) @ ({k}, {n}) -> ({m}, {n}): ".ljust(42), end = "")

    lhs = torch.randn((m, k), dtype = torch.half, device = "cuda:0")
    rhs = torch.randn((k, n), dtype = torch.half, device = "cuda:0")
    out_a = torch.empty((m, n), dtype = torch.half, device = "cuda:0")
    out_b = torch.empty((m, n), dtype = torch.half, device = "cuda:0")

    ext_c.gemm_half_half_half(lhs, rhs, out_a, 1, 0, False)
    ext_c.gemm_half_half_half(lhs, rhs, out_b, 1, 0, True)
    reference = torch.matmul(lhs, rhs)

    # Max absolute element-wise deviation from each alternative result.
    diff_cublas = torch.abs(out_b - out_a).max().item()
    diff_torch = torch.abs(reference - out_a).max().item()
    print(f"diff vs cuBLAS: {diff_cublas:.3f} diff vs Torch: {diff_torch:.3f}")