mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Tests for half GEMM kernels
This commit is contained in:
36
tests/test_hgemm.py
Normal file
36
tests/test_hgemm.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import torch
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
import random
|
||||
|
||||
shapes = [ # m, k, n
|
||||
[ 1, 16, 1 ],
|
||||
[ 16, 16, 16 ],
|
||||
[ 8, 256, 32 ],
|
||||
[ 8, 8, 256 ],
|
||||
]
|
||||
|
||||
for i in range(10):
|
||||
shapes.append([random.randint(1, 200), random.randint(1, 16) * 32, random.randint(1, 10)])
|
||||
shapes.append([random.randint(1, 10), random.randint(1, 4) * 32, random.randint(1, 200)])
|
||||
|
||||
for s in shapes:
|
||||
m, k, n = s[0], s[1], s[2]
|
||||
|
||||
print(f" ({m}, {k}) @ ({k}, {n}) -> ({m}, {n}): ".ljust(42), end = "")
|
||||
|
||||
a = torch.randn((m, k), dtype = torch.half, device = "cuda:0")
|
||||
b = torch.randn((k, n), dtype = torch.half, device = "cuda:0")
|
||||
c = torch.empty((m, n), dtype = torch.half, device = "cuda:0")
|
||||
d = torch.empty((m, n), dtype = torch.half, device = "cuda:0")
|
||||
|
||||
ext_c.gemm_half_half_half(a, b, c, 1, 0, False)
|
||||
ext_c.gemm_half_half_half(a, b, d, 1, 0, True)
|
||||
t = torch.matmul(a, b)
|
||||
|
||||
e_cublas = d - c
|
||||
e_torch = t - c
|
||||
diff_cublas = torch.max(torch.abs(e_cublas)).item()
|
||||
diff_torch = torch.max(torch.abs(e_torch)).item()
|
||||
|
||||
print(f"diff vs cuBLAS: {diff_cublas:.3f} diff vs Torch: {diff_torch:.3f}")
|
||||
|
||||
Reference in New Issue
Block a user