mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 06:19:10 +00:00
63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
import sys, os
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
import pytest
|
|
import torch
|
|
from exllamav3.ext import exllamav3_ext as ext
|
|
from itertools import pairwise
|
|
|
|
device = "cuda:2"
|
|
page_size = 256
|
|
|
|
cache_dims = [
|
|
[2048, page_size, 16 * 128],
|
|
[1024, page_size, 8 * 128],
|
|
[512, page_size, 8 * 128],
|
|
[256, page_size, 8 * 128],
|
|
[100, page_size, 4 * 128],
|
|
[32, page_size, 4 * 128],
|
|
[32, page_size, 48],
|
|
[2560, page_size, 16],
|
|
]
|
|
|
|
cache_dtypes = [torch.half, torch.float]
|
|
|
|
full_opt = [True, False]
|
|
|
|
@pytest.mark.parametrize("cache_dim", cache_dims)
|
|
@pytest.mark.parametrize("cache_dtype", cache_dtypes)
|
|
@pytest.mark.parametrize("full", full_opt)
|
|
@torch.inference_mode()
|
|
def test_rope(cache_dim, cache_dtype, full):
|
|
|
|
torch.manual_seed(0)
|
|
|
|
num_pages = cache_dim[0]
|
|
cache = torch.randn(cache_dim, device = device, dtype = torch.half)
|
|
order = torch.randperm(num_pages, device = device, dtype = torch.int)
|
|
if not full:
|
|
order = order[:num_pages // 4]
|
|
|
|
order = order.repeat_interleave(2)
|
|
m1 = torch.tensor([-1], device = device, dtype = torch.int)
|
|
order = torch.cat([m1, order, m1], dim = -1)
|
|
if not full:
|
|
order = torch.cat([order, order], dim = -1)
|
|
|
|
ref_cache = cache.clone()
|
|
ref_order = order.tolist()
|
|
|
|
for _ in range(3):
|
|
temp = torch.empty_like(ref_cache[0])
|
|
for i in range(0, len(ref_order), 2):
|
|
a = ref_order[i]
|
|
b = ref_order[i + 1]
|
|
dst = ref_cache[a, ...] if a >= 0 else temp
|
|
src = ref_cache[b, ...] if b >= 0 else temp
|
|
dst.copy_(src)
|
|
|
|
temp = torch.empty_like(cache[0])
|
|
ext.cache_rotate(cache, order, temp)
|
|
|
|
torch.testing.assert_close(cache, ref_cache, rtol = 0, atol = 0)
|
|
|