Mirror of https://github.com/turboderp-org/exllamav3.git
RoPE: Fix multidim rotation
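For context: in the multidimensional RoPE scheme this commit fixes, the head dimension is split into rotate_dims segments of partial_head_dim channels, and each segment is rotated by its own position index taken from a (bsz, seq_len, rotate_dims) position_ids tensor. A minimal PyTorch sketch of the NEOX-style variant, condensed from the reference implementation in the test below (the function name is made up for illustration, and attn_factor scaling is omitted):

import torch

def multidim_rope_neox(x, position_ids, inv_freq):
    # x: (bsz, seq_len, num_heads, head_dim)
    # position_ids: (bsz, seq_len, rotate_dims), integer positions per rotation dimension
    # inv_freq: (partial_head_dim // 2,) where head_dim == partial_head_dim * rotate_dims
    rotate_dims = position_ids.shape[-1]
    partial = x.shape[-1] // rotate_dims
    out = []
    for rdim in range(rotate_dims):
        seg = x[..., rdim * partial : (rdim + 1) * partial].float()
        pos = position_ids[..., rdim].float()
        freqs = torch.einsum("bi,j->bij", pos, inv_freq.float())  # (bsz, seq_len, partial // 2)
        sin = freqs.sin().unsqueeze(2)  # unsqueeze to broadcast over the heads axis
        cos = freqs.cos().unsqueeze(2)
        x1, x2 = seg[..., : partial // 2], seg[..., partial // 2 :]
        out.append(torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim = -1))
    return torch.cat(out, dim = -1).half()

With a 2-D grid like the test's (pos % 20, pos // 20), the first segment rotates with the fast-cycling coordinate and the second with the row count.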
@@ -12,6 +12,7 @@
 #define ROPESTYLE_NEOX 2
 #define ROPESTYLE_NANOCHAT 3
 #define MAX_NUM_THREADS 1024
+#define MAX_ROTATE_DIMS 4
 
 using bfloat16 = __nv_bfloat16;
 using bfloat162 = __nv_bfloat162;
@@ -44,43 +45,64 @@ void rope_kernel
     const float llama_4_scaling_beta,
     const int llama_4_scaling_original,
     bool post_rope_norm,
-    int position_ids_stride
+    int position_ids_stride,
+    int rotate_dims
 )
 {
-    // Get position
     int batch = blockIdx.y;
     int token_pos = blockIdx.x;
-    int pos = token_pos + position;
-    if (positions)
-        pos = token_pos + positions[batch];
-    else if (position_ids)
-        pos = position_ids[batch * seq_len + token_pos * position_ids_stride];
-
-    // Apply Llama 4 scaling
-    if (llama_4_scaling_beta > 0.0f)
-    {
-        float scaling = 1.0f + llama_4_scaling_beta * __logf(1.0f + (float)(pos / llama_4_scaling_original));
-        attn_factor *= scaling;
-    }
-
-    // Load inv_freq, compute sin/cos
     int t = threadIdx.x;
    int t_head = threadIdx.y;
 
-    float sin;
-    float cos;
-    if (!inv_freq_table)
+    auto get_pos = [&] (int rdim) -> int
     {
-        float fr = inv_freq[t];
-        float pf = __int2float_rn(pos);
-        sin = __sinf(fr * pf) * attn_factor;
-        cos = __cosf(fr * pf) * attn_factor;
-    }
-    else
+        int pos = token_pos + position;
+        if (positions)
+            pos = token_pos + positions[batch];
+        else if (position_ids)
+        {
+            int idx = (batch * seq_len + token_pos) * position_ids_stride;
+            if (position_ids_stride > 1) idx += rdim;
+            pos = position_ids[idx];
+        }
+        return pos;
+    };
+
+    auto get_sincos = [&] (int rdim, float &sin, float &cos)
     {
-        float fr = inv_freq[batch * inv_freq_stride + pos * partial_head_dim / 2 + t];
-        sin = __sinf(fr) * attn_factor;
-        cos = __cosf(fr) * attn_factor;
+        int pos = get_pos(rdim);
+        float local_attn_factor = attn_factor;
+
+        // Apply Llama 4 scaling
+        if (llama_4_scaling_beta > 0.0f)
+        {
+            float scaling = 1.0f + llama_4_scaling_beta * __logf(1.0f + (float)(pos / llama_4_scaling_original));
+            local_attn_factor *= scaling;
+        }
+
+        if (!inv_freq_table)
+        {
+            float fr = inv_freq[t];
+            float pf = __int2float_rn(pos);
+            sin = __sinf(fr * pf) * local_attn_factor;
+            cos = __cosf(fr * pf) * local_attn_factor;
+        }
+        else
+        {
+            float fr = inv_freq[batch * inv_freq_stride + pos * partial_head_dim / 2 + t];
+            sin = __sinf(fr) * local_attn_factor;
+            cos = __cosf(fr) * local_attn_factor;
+        }
+    };
+
+    float sin_cache[MAX_ROTATE_DIMS];
+    float cos_cache[MAX_ROTATE_DIMS];
+    if (t < partial_head_dim / 2)
+    {
+        for (int rdim = 0; rdim < rotate_dims; ++rdim)
+        {
+            get_sincos(rdim, sin_cache[rdim], cos_cache[rdim]);
+        }
     }
 
     // Shared buffer
@@ -129,36 +151,43 @@ void rope_kernel
     // Apply embeddings
     auto apply_rope = [&] ()
     {
-        if (t < partial_head_dim / 2)
+        for (int rdim = 0; rdim < rotate_dims; ++rdim)
         {
-            if constexpr (rope_mode == ROPESTYLE_NEOX)
+            int offset = partial_head_dim * rdim;
+            if (t < partial_head_dim / 2)
             {
-                float v1 = __half2float(sh_head[t]);
-                float v2 = __half2float(sh_head[t + partial_head_dim / 2]);
-                float r1 = v1 * cos - v2 * sin;
-                float r2 = v2 * cos + v1 * sin;
-                sh_head[t] = __float2half_rn(r1);
-                sh_head[t + partial_head_dim / 2] = __float2half_rn(r2);
-            }
-            else if constexpr (rope_mode == ROPESTYLE_NANOCHAT)
-            {
-                float v1 = __half2float(sh_head[t]);
-                float v2 = __half2float(sh_head[t + partial_head_dim / 2]);
-                float r1 = v1 * cos + v2 * sin;
-                float r2 = v2 * cos - v1 * sin;
-                sh_head[t] = __float2half_rn(r1);
-                sh_head[t + partial_head_dim / 2] = __float2half_rn(r2);
-            }
-            else if constexpr (rope_mode == ROPESTYLE_GPTJ)
-            {
-                half2 *tptr = (half2*)(sh_head + t * 2);
-                half2 v = *tptr;
-                float v1 = __low2float(v);
-                float v2 = __high2float(v);
-                float r1 = v1 * cos - v2 * sin;
-                float r2 = v2 * cos + v1 * sin;
-                v = __floats2half2_rn(r1, r2);
-                *tptr = v;
+                float sin = sin_cache[rdim];
+                float cos = cos_cache[rdim];
+
+                if constexpr (rope_mode == ROPESTYLE_NEOX)
+                {
+                    float v1 = __half2float(sh_head[offset + t]);
+                    float v2 = __half2float(sh_head[offset + t + partial_head_dim / 2]);
+                    float r1 = v1 * cos - v2 * sin;
+                    float r2 = v2 * cos + v1 * sin;
+                    sh_head[offset + t] = __float2half_rn(r1);
+                    sh_head[offset + t + partial_head_dim / 2] = __float2half_rn(r2);
+                }
+                else if constexpr (rope_mode == ROPESTYLE_NANOCHAT)
+                {
+                    float v1 = __half2float(sh_head[offset + t]);
+                    float v2 = __half2float(sh_head[offset + t + partial_head_dim / 2]);
+                    float r1 = v1 * cos + v2 * sin;
+                    float r2 = v2 * cos - v1 * sin;
+                    sh_head[offset + t] = __float2half_rn(r1);
+                    sh_head[offset + t + partial_head_dim / 2] = __float2half_rn(r2);
+                }
+                else if constexpr (rope_mode == ROPESTYLE_GPTJ)
+                {
+                    half2 *tptr = (half2*)(sh_head + offset + t * 2);
+                    half2 v = *tptr;
+                    float v1 = __low2float(v);
+                    float v2 = __high2float(v);
+                    float r1 = v1 * cos - v2 * sin;
+                    float r2 = v2 * cos + v1 * sin;
+                    v = __floats2half2_rn(r1, r2);
+                    *tptr = v;
+                }
             }
         }
         __syncthreads();
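Aside from the new offset and the cached sin/cos, the three rope_mode branches above differ only in which channels pair up and in the rotation's sign: NEOX pairs channel t with t + partial_head_dim / 2, GPTJ pairs adjacent channels 2t and 2t + 1, and NANOCHAT uses the NEOX pairing with the opposite rotation direction. A compact sketch of the pairings on a single head (illustrative Python with a hypothetical rotate_pairs helper, not kernel code):

import torch

def rotate_pairs(head, sin, cos, style):
    # head: (head_dim,); sin, cos: (head_dim // 2,) for one position
    d = head.shape[-1]
    if style in ("neox", "nanochat"):
        v1, v2 = head[: d // 2], head[d // 2 :]  # split-half pairing
    else:  # "gptj": interleaved pairing of channels 2t and 2t + 1
        v1, v2 = head[0::2], head[1::2]
    if style == "nanochat":
        sin = -sin  # same pairing as NEOX, opposite rotation direction
    r1 = v1 * cos - v2 * sin
    r2 = v2 * cos + v1 * sin
    if style == "gptj":
        return torch.stack((r1, r2), dim = -1).flatten(-2)  # re-interleave
    return torch.cat((r1, r2))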
@@ -305,6 +334,8 @@ void rope
     int head_dim = q.size(3);
     int partial_head_dim = inv_freq.size(-1) * 2;
     int inv_freq_stride = 0;
+    TORCH_CHECK(rotate_dims > 0 && rotate_dims <= MAX_ROTATE_DIMS, "rotate_dims out of range");
+    TORCH_CHECK(rotate_dims == 1 || head_dim == partial_head_dim * rotate_dims, "rotate_dims is inconsistent with inv_freq and head_dim");
 
     const half* q_ptr = (half*) q.data_ptr();
     half* out_q_ptr = (half*) out_q.data_ptr();
@@ -336,6 +367,7 @@ void rope
 
     uint32_t* positions_ptr = (uint32_t*) OPTPTR(positions);
     uint32_t* position_ids_ptr = (uint32_t*) OPTPTR(position_ids);
+    int position_ids_stride = 1;
     TORCH_CHECK_DTYPE_OPT(positions, kInt);
     TORCH_CHECK_DTYPE_OPT(position_ids, kInt);
     TORCH_CHECK((positions_ptr != nullptr) + (position_ids_ptr != nullptr) <= 1, "invalid arguments")
@@ -353,6 +385,7 @@ void rope
         TORCH_CHECK(rd == 2 || (rd == 3 && position_ids.value().size(-1) == rotate_dims), "position_ids wrong number of dims")
         TORCH_CHECK(position_ids.value().size(0) == bsz, "position_ids is incorrect shape");
         TORCH_CHECK(position_ids.value().size(1) == seq_len, "position_ids is incorrect shape");
+        if (rd == 3) position_ids_stride = rotate_dims;
     }
 
     void* q_norm_ptr = (void*) OPTPTR(q_norm);
@@ -375,34 +408,26 @@ void rope
     int parallel_heads = MIN((MAX_NUM_THREADS / thr), num_heads_q + num_heads_k);
     dim3 threads(thr, parallel_heads);
 
-    for (int rdim = rotate_dims - 1; rdim >= 0; --rdim)
+    #define ARGS q_ptr, out_q_ptr, k_ptr, out_k_ptr, inv_freq_ptr, bsz, \
+        seq_len, num_heads_q, num_heads_k, head_dim, partial_head_dim, position, positions_ptr, \
+        position_ids_ptr, attn_factor, q_norm_ptr, k_norm_ptr, norm_eps, norm_constant_bias, inv_freq_table, \
+        inv_freq_stride, llama_4_scaling_beta, llama_4_scaling_original, post_rope_norm, position_ids_stride, rotate_dims
+
+    if (norm_fp16)
     {
-        bool last_dim = rdim == 0;
-        int offset = partial_head_dim * rdim;
-        const uint32_t* position_ids_ptr_ = position_ids_ptr;
-        if (rotate_dims > 1) position_ids_ptr_ += rdim;
-
-        #define ARGS q_ptr + offset, out_q_ptr + offset, k_ptr + offset, out_k_ptr + offset, inv_freq_ptr, bsz, \
-            seq_len, num_heads_q, num_heads_k, head_dim, partial_head_dim, position, positions_ptr, \
-            position_ids_ptr_, attn_factor, last_dim ? q_norm_ptr : nullptr, \
-            last_dim ? k_norm_ptr : nullptr, norm_eps, norm_constant_bias, inv_freq_table, \
-            inv_freq_stride, llama_4_scaling_beta, llama_4_scaling_original, \
-            post_rope_norm && last_dim, rotate_dims
-
-        if (norm_fp16)
-        {
-            if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, false><<<blocks, threads, 0, stream>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, false><<<blocks, threads, 0, stream>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, false><<<blocks, threads, 0, stream>>>(ARGS);
-        }
-        else if (norm_bf16)
-        {
-            if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, true><<<blocks, threads, 0, stream>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, true><<<blocks, threads, 0, stream>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, true><<<blocks, threads, 0, stream>>>(ARGS);
-        }
-        else TORCH_CHECK(false, "rope: incorrect norm dtype");
+        if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, false><<<blocks, threads, 0, stream>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, false><<<blocks, threads, 0, stream>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, false><<<blocks, threads, 0, stream>>>(ARGS);
+    }
+    else if (norm_bf16)
+    {
+        if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, true><<<blocks, threads, 0, stream>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, true><<<blocks, threads, 0, stream>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, true><<<blocks, threads, 0, stream>>>(ARGS);
     }
+    else TORCH_CHECK(false, "rope: incorrect norm dtype");
+
     #undef ARGS
+
     cuda_check(cudaPeekAtLastError());
 }
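With this change the host no longer loops over rotation dimensions: the kernel receives position_ids_stride and rotate_dims and resolves each dimension's position itself via get_pos. Note the fix in the flattened index: the old expression batch * seq_len + token_pos * position_ids_stride scaled only the token term, so batched multidimensional lookups read the wrong rows; the new one scales the whole (batch, token) offset before adding rdim. The convention for a contiguous (bsz, seq_len, rotate_dims) tensor can be sanity-checked with toy values:

import torch

bsz, seq_len, rotate_dims = 2, 7, 3  # toy sizes
position_ids = torch.arange(bsz * seq_len * rotate_dims, dtype = torch.int).view(bsz, seq_len, rotate_dims)
flat = position_ids.flatten()

stride = rotate_dims  # position_ids_stride when position_ids has a trailing rotate_dims axis
for batch in range(bsz):
    for token_pos in range(seq_len):
        for rdim in range(rotate_dims):
            # the corrected index: scale the whole (batch, token) offset, then add rdim
            idx = (batch * seq_len + token_pos) * stride + rdim
            assert flat[idx] == position_ids[batch, token_pos, rdim]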
@@ -102,4 +102,124 @@ def test_rope(qk_dim, rope_style, use_norm):
     run(0, torch.randint(size = (bsz,), low = 0, high = 49, dtype = torch.int, device = device), None)
 
     # Batched position ids
     run(0, None, torch.randint(size = (bsz, seq_len), low = 0, high = 117, dtype = torch.int, device = device))
+    run(0, None, torch.randint(size = (bsz, seq_len), low = 0, high = 117, dtype = torch.int, device = device))
+
+
+@pytest.mark.parametrize("rope_style", rope_styles)
+@pytest.mark.parametrize("use_norm", norm_opt)
+@pytest.mark.parametrize("in_place", [False, True])
+@torch.inference_mode()
+def test_rope_multidim(rope_style, use_norm, in_place):
+
+    bsz = 2
+    seq_len = 280
+    num_heads = 16
+    head_dim = 72
+    rotate_dims = 2
+    partial_head_dim = head_dim // rotate_dims
+
+    rope_layer = RoPE(
+        device = device,
+        rope_settings = RopeSettings(
+            rope_theta = 100.0,
+            head_dim = partial_head_dim,
+            rope_scaling = None,
+            max_position_embeddings = 131072,
+            partial_rotary_factor = 1.0,
+            rope_style = rope_style,
+            rotate_dims = rotate_dims,
+        )
+    )
+
+    def qk():
+        torch.manual_seed(0)
+        q_pr = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)
+        k_pr = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)
+        return q_pr, k_pr
+
+    def apply_norm(
+        x: torch.Tensor,
+        w: torch.Tensor | None,
+        eps: float,
+        constant_bias: float
+    ) -> torch.Tensor:
+        x = x.float()
+        var = x.pow(2).mean(dim = -1, keepdim = True) + eps
+        x = x * torch.rsqrt(var)
+        if w is not None:
+            x = x * (w.float() + constant_bias)
+        return x.half()
+
+    def apply_rope_embed(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
+        x = x.transpose(1, 2)
+        sin = sin.unsqueeze(1)
+        cos = cos.unsqueeze(1)
+        if rope_style == RopeStyle.NEOX:
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2:]
+            xr = torch.cat((-x2, x1), dim = -1)
+        else:
+            x1 = x[..., 0::2]
+            x2 = x[..., 1::2]
+            xr = torch.stack((-x2, x1), dim = -1).flatten(-2)
+        return (x * cos + xr * sin).transpose(1, 2).half()
+
+    def apply_multidim_ref(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        position_ids: torch.Tensor,
+        q_norm: torch.Tensor | None,
+        k_norm: torch.Tensor | None,
+        eps: float,
+        constant_bias: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if q_norm is not None:
+            q = apply_norm(q, q_norm, eps, constant_bias)
+            k = apply_norm(k, k_norm, eps, constant_bias)
+
+        out_q = []
+        out_k = []
+        for rdim in range(rotate_dims):
+            start = partial_head_dim * rdim
+            end = start + partial_head_dim
+            pos = position_ids[:, :, rdim].float()
+            freqs = torch.einsum("bi,j->bij", pos, rope_layer.inv_freq.float())
+            sin = freqs.sin() * rope_layer.attn_factor
+            cos = freqs.cos() * rope_layer.attn_factor
+            if rope_style == RopeStyle.NEOX:
+                sin = torch.cat((sin, sin), dim = -1)
+                cos = torch.cat((cos, cos), dim = -1)
+            else:
+                sin = torch.repeat_interleave(sin, 2, dim = -1)
+                cos = torch.repeat_interleave(cos, 2, dim = -1)
+            out_q.append(apply_rope_embed(q[..., start : end], sin, cos))
+            out_k.append(apply_rope_embed(k[..., start : end], sin, cos))
+
+        return torch.cat(out_q, dim = -1), torch.cat(out_k, dim = -1)
+
+    base = torch.arange(seq_len, dtype = torch.int, device = device)
+    position_ids = torch.stack((base % 20, base // 20), dim = -1).unsqueeze(0).repeat(bsz, 1, 1)
+    position_ids[1, :, 0] += 3
+    position_ids[1, :, 1] += 5
+
+    q, k = qk()
+    eps = 1e-6
+    constant_bias = 0.0
+    if use_norm:
+        torch.manual_seed(1)
+        norm_q = torch.randn(head_dim, device = device, dtype = torch.half) / 2.0
+        norm_k = torch.randn(head_dim, device = device, dtype = torch.half) / 2.0
+    else:
+        norm_q = None
+        norm_k = None
+
+    q_ref, k_ref = apply_multidim_ref(q, k, position_ids, norm_q, norm_k, eps, constant_bias)
+    q, k = qk()
+    q_pre = q.clone()
+    k_pre = k.clone()
+    q_out, k_out = rope_layer.apply(q, k, 0, None, position_ids, in_place, norm_q, norm_k, eps, constant_bias)
+    torch.testing.assert_close(q_out, q_ref, rtol = 3e-3, atol = 3e-3)
+    torch.testing.assert_close(k_out, k_ref, rtol = 3e-3, atol = 3e-3)
+    if not in_place:
+        torch.testing.assert_close(q, q_pre, rtol = 0, atol = 0)
+        torch.testing.assert_close(k, k_pre, rtol = 0, atol = 0)
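For intuition about the test data: stacking (base % 20, base // 20) factors each linear token index into a 2-D coordinate, so dimension 0 cycles through 0..19 while dimension 1 counts completed rows, and the two head-dim segments see genuinely different position sequences. A quick illustration with a shorter range (expected values in comments):

import torch

base = torch.arange(45, dtype = torch.int)
grid = torch.stack((base % 20, base // 20), dim = -1)
print(grid[19])  # tensor([19,  0], dtype=torch.int32)
print(grid[20])  # tensor([ 0,  1], dtype=torch.int32)
print(grid[43])  # tensor([ 3,  2], dtype=torch.int32)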