Files
exllamav3/tests/test_cache_rotate.py
2025-05-28 22:23:08 +02:00

63 lines
1.8 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import torch
from exllamav3.ext import exllamav3_ext as ext
from itertools import pairwise
# Device and page geometry shared by all parametrizations.
device = "cuda:2"
page_size = 256

# (num_pages, row_width) combinations: large/small page counts (including a
# non-power-of-two count) crossed with wide and narrow rows.
_page_shapes = (
    (2048, 16 * 128),
    (1024, 8 * 128),
    (512, 8 * 128),
    (256, 8 * 128),
    (100, 4 * 128),
    (32, 4 * 128),
    (32, 48),
    (2560, 16),
)
cache_dims = [[pages, page_size, width] for pages, width in _page_shapes]

# Cache storage dtypes and full-vs-partial rotation variants.
cache_dtypes = [torch.half, torch.float]
full_opt = [True, False]
@pytest.mark.parametrize("cache_dim", cache_dims)
@pytest.mark.parametrize("cache_dtype", cache_dtypes)
@pytest.mark.parametrize("full", full_opt)
@torch.inference_mode()
def test_rope(cache_dim, cache_dtype, full):
    """
    Verify ext.cache_rotate against a pure-PyTorch reference.

    `order` is a flat sequence of (dst, src) page-index pairs; an index of -1
    denotes a single scratch page, which lets a cyclic permutation be applied
    in place (first pair stashes a page into scratch, last pair restores it).
    Three rotations are applied back to back and the kernel result must match
    the reference bit-exactly (copies only, no arithmetic).
    """
    torch.manual_seed(0)
    num_pages = cache_dim[0]
    # BUG FIX: previously hard-coded dtype = torch.half, which made the
    # torch.float parametrization silently re-test fp16. Honor cache_dtype.
    cache = torch.randn(cache_dim, device = device, dtype = cache_dtype)
    order = torch.randperm(num_pages, device = device, dtype = torch.int)
    if not full:
        # Partial rotation: permute only a quarter of the pages
        order = order[:num_pages // 4]
    # Duplicate each index so consecutive elements form (dst, src) pairs
    order = order.repeat_interleave(2)
    # Bracket with -1 sentinels: (-1, p0) stashes to scratch, (pk, -1) restores
    m1 = torch.tensor([-1], device = device, dtype = torch.int)
    order = torch.cat([m1, order, m1], dim = -1)
    if not full:
        # Apply the same partial cycle twice within a single kernel call
        order = torch.cat([order, order], dim = -1)
    ref_cache = cache.clone()
    ref_order = order.tolist()
    for _ in range(3):
        # Reference: walk the (dst, src) pairs, -1 selects the scratch page
        temp = torch.empty_like(ref_cache[0])
        for i in range(0, len(ref_order), 2):
            a = ref_order[i]
            b = ref_order[i + 1]
            dst = ref_cache[a, ...] if a >= 0 else temp
            src = ref_cache[b, ...] if b >= 0 else temp
            dst.copy_(src)
        # Kernel under test, given its own scratch page
        temp = torch.empty_like(cache[0])
        ext.cache_rotate(cache, order, temp)
        # Pure page copies: demand bitwise equality
        torch.testing.assert_close(cache, ref_cache, rtol = 0, atol = 0)