RoPE: Fix multidim rotation

Ycros
2026-05-01 09:03:32 +00:00
parent 4e587cd19b
commit 993b1cbff5
2 changed files with 228 additions and 83 deletions
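Context for the change: with rotate_dims > 1, each head's channels are treated as rotate_dims independent RoPE blocks of partial_head_dim channels, and block r is rotated using the r-th component of a [bsz, seq_len, rotate_dims] position_ids tensor (e.g. 2D row/column positions). A minimal NEOX-style reference sketch in PyTorch; the helper name rope_multidim_ref is illustrative, and it assumes inv_freq holds partial_head_dim / 2 frequencies:

import torch

def rope_multidim_ref(x, position_ids, inv_freq, rotate_dims):
    # x: [bsz, seq, heads, head_dim]; position_ids: [bsz, seq, rotate_dims]
    partial = x.shape[-1] // rotate_dims
    out = []
    for r in range(rotate_dims):
        blk = x[..., r * partial : (r + 1) * partial].float()
        pos = position_ids[..., r].float()                        # [bsz, seq]
        freqs = torch.einsum("bs,f->bsf", pos, inv_freq.float())  # [bsz, seq, partial // 2]
        sin = freqs.sin().unsqueeze(2)                            # broadcast over heads
        cos = freqs.cos().unsqueeze(2)
        x1, x2 = blk[..., : partial // 2], blk[..., partial // 2 :]
        out.append(torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim = -1))
    return torch.cat(out, dim = -1).to(x.dtype)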


@@ -12,6 +12,7 @@
#define ROPESTYLE_NEOX 2
#define ROPESTYLE_NANOCHAT 3
#define MAX_NUM_THREADS 1024
#define MAX_ROTATE_DIMS 4
using bfloat16 = __nv_bfloat16;
using bfloat162 = __nv_bfloat162;
@@ -44,43 +45,64 @@ void rope_kernel
const float llama_4_scaling_beta,
const int llama_4_scaling_original,
bool post_rope_norm,
int position_ids_stride
int position_ids_stride,
int rotate_dims
)
{
// Get position
int batch = blockIdx.y;
int token_pos = blockIdx.x;
int pos = token_pos + position;
if (positions)
pos = token_pos + positions[batch];
else if (position_ids)
pos = position_ids[batch * seq_len + token_pos * position_ids_stride];
// Apply Llama 4 scaling
if (llama_4_scaling_beta > 0.0f)
{
float scaling = 1.0f + llama_4_scaling_beta * __logf(1.0f + (float)(pos / llama_4_scaling_original));
attn_factor *= scaling;
}
// Load inv_freq, compute sin/cos
int t = threadIdx.x;
int t_head = threadIdx.y;
float sin;
float cos;
if (!inv_freq_table)
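// Per-rotation-dim position lookup: with a 3D position_ids tensor (stride > 1), rotation dimension rdim reads its own position component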
auto get_pos = [&] (int rdim) -> int
{
float fr = inv_freq[t];
float pf = __int2float_rn(pos);
sin = __sinf(fr * pf) * attn_factor;
cos = __cosf(fr * pf) * attn_factor;
}
else
int pos = token_pos + position;
if (positions)
pos = token_pos + positions[batch];
else if (position_ids)
{
int idx = (batch * seq_len + token_pos) * position_ids_stride;
if (position_ids_stride > 1) idx += rdim;
pos = position_ids[idx];
}
return pos;
};
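// sin/cos for one rotation dimension; Llama 4 scaling uses floor(pos / original) via integer division and is applied per resolved position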
auto get_sincos = [&] (int rdim, float &sin, float &cos)
{
float fr = inv_freq[batch * inv_freq_stride + pos * partial_head_dim / 2 + t];
sin = __sinf(fr) * attn_factor;
cos = __cosf(fr) * attn_factor;
int pos = get_pos(rdim);
float local_attn_factor = attn_factor;
// Apply Llama 4 scaling
if (llama_4_scaling_beta > 0.0f)
{
float scaling = 1.0f + llama_4_scaling_beta * __logf(1.0f + (float)(pos / llama_4_scaling_original));
local_attn_factor *= scaling;
}
if (!inv_freq_table)
{
float fr = inv_freq[t];
float pf = __int2float_rn(pos);
sin = __sinf(fr * pf) * local_attn_factor;
cos = __cosf(fr * pf) * local_attn_factor;
}
else
{
float fr = inv_freq[batch * inv_freq_stride + pos * partial_head_dim / 2 + t];
sin = __sinf(fr) * local_attn_factor;
cos = __cosf(fr) * local_attn_factor;
}
};
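// Precompute sin/cos for every rotation dimension; MAX_ROTATE_DIMS bounds the per-thread cache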
float sin_cache[MAX_ROTATE_DIMS];
float cos_cache[MAX_ROTATE_DIMS];
if (t < partial_head_dim / 2)
{
for (int rdim = 0; rdim < rotate_dims; ++rdim)
{
get_sincos(rdim, sin_cache[rdim], cos_cache[rdim]);
}
}
// Shared buffer
@@ -129,36 +151,43 @@ void rope_kernel
// Apply embeddings
auto apply_rope = [&] ()
{
if (t < partial_head_dim / 2)
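// Rotate each block of partial_head_dim channels with its own cached angles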
for (int rdim = 0; rdim < rotate_dims; ++rdim)
{
if constexpr (rope_mode == ROPESTYLE_NEOX)
int offset = partial_head_dim * rdim;
if (t < partial_head_dim / 2)
{
float v1 = __half2float(sh_head[t]);
float v2 = __half2float(sh_head[t + partial_head_dim / 2]);
float r1 = v1 * cos - v2 * sin;
float r2 = v2 * cos + v1 * sin;
sh_head[t] = __float2half_rn(r1);
sh_head[t + partial_head_dim / 2] = __float2half_rn(r2);
}
else if constexpr (rope_mode == ROPESTYLE_NANOCHAT)
{
float v1 = __half2float(sh_head[t]);
float v2 = __half2float(sh_head[t + partial_head_dim / 2]);
float r1 = v1 * cos + v2 * sin;
float r2 = v2 * cos - v1 * sin;
sh_head[t] = __float2half_rn(r1);
sh_head[t + partial_head_dim / 2] = __float2half_rn(r2);
}
else if constexpr (rope_mode == ROPESTYLE_GPTJ)
{
half2 *tptr = (half2*)(sh_head + t * 2);
half2 v = *tptr;
float v1 = __low2float(v);
float v2 = __high2float(v);
float r1 = v1 * cos - v2 * sin;
float r2 = v2 * cos + v1 * sin;
v = __floats2half2_rn(r1, r2);
*tptr = v;
float sin = sin_cache[rdim];
float cos = cos_cache[rdim];
if constexpr (rope_mode == ROPESTYLE_NEOX)
{
float v1 = __half2float(sh_head[offset + t]);
float v2 = __half2float(sh_head[offset + t + partial_head_dim / 2]);
float r1 = v1 * cos - v2 * sin;
float r2 = v2 * cos + v1 * sin;
sh_head[offset + t] = __float2half_rn(r1);
sh_head[offset + t + partial_head_dim / 2] = __float2half_rn(r2);
}
else if constexpr (rope_mode == ROPESTYLE_NANOCHAT)
{
float v1 = __half2float(sh_head[offset + t]);
float v2 = __half2float(sh_head[offset + t + partial_head_dim / 2]);
float r1 = v1 * cos + v2 * sin;
float r2 = v2 * cos - v1 * sin;
sh_head[offset + t] = __float2half_rn(r1);
sh_head[offset + t + partial_head_dim / 2] = __float2half_rn(r2);
}
else if constexpr (rope_mode == ROPESTYLE_GPTJ)
{
half2 *tptr = (half2*)(sh_head + offset + t * 2);
half2 v = *tptr;
float v1 = __low2float(v);
float v2 = __high2float(v);
float r1 = v1 * cos - v2 * sin;
float r2 = v2 * cos + v1 * sin;
v = __floats2half2_rn(r1, r2);
*tptr = v;
}
}
}
__syncthreads();
@@ -305,6 +334,8 @@ void rope
int head_dim = q.size(3);
int partial_head_dim = inv_freq.size(-1) * 2;
int inv_freq_stride = 0;
TORCH_CHECK(rotate_dims > 0 && rotate_dims <= MAX_ROTATE_DIMS, "rotate_dims out of range");
TORCH_CHECK(rotate_dims == 1 || head_dim == partial_head_dim * rotate_dims, "rotate_dims is inconsistent with inv_freq and head_dim");
const half* q_ptr = (half*) q.data_ptr();
half* out_q_ptr = (half*) out_q.data_ptr();
@@ -336,6 +367,7 @@ void rope
uint32_t* positions_ptr = (uint32_t*) OPTPTR(positions);
uint32_t* position_ids_ptr = (uint32_t*) OPTPTR(position_ids);
int position_ids_stride = 1;
TORCH_CHECK_DTYPE_OPT(positions, kInt);
TORCH_CHECK_DTYPE_OPT(position_ids, kInt);
TORCH_CHECK((positions_ptr != nullptr) + (position_ids_ptr != nullptr) <= 1, "invalid arguments");
@@ -353,6 +385,7 @@ void rope
TORCH_CHECK(rd == 2 || (rd == 3 && position_ids.value().size(-1) == rotate_dims), "position_ids wrong number of dims");
TORCH_CHECK(position_ids.value().size(0) == bsz, "position_ids is incorrect shape");
TORCH_CHECK(position_ids.value().size(1) == seq_len, "position_ids is incorrect shape");
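// A 3D position_ids tensor supplies one position per rotation dimension; stride > 1 selects the component inside the kernel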
if (rd == 3) position_ids_stride = rotate_dims;
}
void* q_norm_ptr = (void*) OPTPTR(q_norm);
@@ -375,34 +408,26 @@ void rope
int parallel_heads = MIN((MAX_NUM_THREADS / thr), num_heads_q + num_heads_k);
dim3 threads(thr, parallel_heads);
for (int rdim = rotate_dims - 1; rdim >= 0; --rdim)
#define ARGS q_ptr, out_q_ptr, k_ptr, out_k_ptr, inv_freq_ptr, bsz, \
seq_len, num_heads_q, num_heads_k, head_dim, partial_head_dim, position, positions_ptr, \
position_ids_ptr, attn_factor, q_norm_ptr, k_norm_ptr, norm_eps, norm_constant_bias, inv_freq_table, \
inv_freq_stride, llama_4_scaling_beta, llama_4_scaling_original, post_rope_norm, position_ids_stride, rotate_dims
if (norm_fp16)
{
bool last_dim = rdim == 0;
int offset = partial_head_dim * rdim;
const uint32_t* position_ids_ptr_ = position_ids_ptr;
if (rotate_dims > 1) position_ids_ptr_ += rdim;
#define ARGS q_ptr + offset, out_q_ptr + offset, k_ptr + offset, out_k_ptr + offset, inv_freq_ptr, bsz, \
seq_len, num_heads_q, num_heads_k, head_dim, partial_head_dim, position, positions_ptr, \
position_ids_ptr_, attn_factor, last_dim ? q_norm_ptr : nullptr, \
last_dim ? k_norm_ptr : nullptr, norm_eps, norm_constant_bias, inv_freq_table, \
inv_freq_stride, llama_4_scaling_beta, llama_4_scaling_original, \
post_rope_norm && last_dim, rotate_dims
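// Dispatch on norm dtype and RoPE style; the kernel loops over all rotation dims internally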
if (norm_fp16)
{
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, false><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, false><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, false><<<blocks, threads, 0, stream>>>(ARGS);
}
else if (norm_bf16)
{
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, true><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, true><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, true><<<blocks, threads, 0, stream>>>(ARGS);
}
else TORCH_CHECK(false, "rope: incorrect norm dtype");
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, false><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, false><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, false><<<blocks, threads, 0, stream>>>(ARGS);
}
else if (norm_bf16)
{
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, true><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, true><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, true><<<blocks, threads, 0, stream>>>(ARGS);
}
else TORCH_CHECK(false, "rope: incorrect norm dtype");
#undef ARGS
cuda_check(cudaPeekAtLastError());
}
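For reference, the shape checks above admit two position_ids layouts; a short sketch of both (tensor names are illustrative):

import torch
bsz, seq_len, rotate_dims = 2, 280, 2
# stride 1: one position shared by all rotation dims
pos_shared = torch.zeros((bsz, seq_len), dtype = torch.int32)
# stride rotate_dims: one position per rotation dim
pos_per_dim = torch.zeros((bsz, seq_len, rotate_dims), dtype = torch.int32)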


@@ -102,4 +102,124 @@ def test_rope(qk_dim, rope_style, use_norm):
run(0, torch.randint(size = (bsz,), low = 0, high = 49, dtype = torch.int, device = device), None)
# Batched position ids
run(0, None, torch.randint(size = (bsz, seq_len), low = 0, high = 117, dtype = torch.int, device = device))
@pytest.mark.parametrize("rope_style", rope_styles)
@pytest.mark.parametrize("use_norm", norm_opt)
@pytest.mark.parametrize("in_place", [False, True])
@torch.inference_mode()
def test_rope_multidim(rope_style, use_norm, in_place):
bsz = 2
seq_len = 280
num_heads = 16
head_dim = 72
rotate_dims = 2
partial_head_dim = head_dim // rotate_dims
rope_layer = RoPE(
device = device,
rope_settings = RopeSettings(
rope_theta = 100.0,
head_dim = partial_head_dim,
rope_scaling = None,
max_position_embeddings = 131072,
partial_rotary_factor = 1.0,
rope_style = rope_style,
rotate_dims = rotate_dims,
)
)
def qk():
torch.manual_seed(0)
q_pr = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)
k_pr = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)
return q_pr, k_pr
def apply_norm(
x: torch.Tensor,
w: torch.Tensor | None,
eps: float,
constant_bias: float
) -> torch.Tensor:
x = x.float()
var = x.pow(2).mean(dim = -1, keepdim = True) + eps
x = x * torch.rsqrt(var)
if w is not None:
x = x * (w.float() + constant_bias)
return x.half()
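# Reference rotation: NEOX splits the head into two halves, other styles pair even/odd (interleaved) channels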
def apply_rope_embed(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
sin = sin.unsqueeze(1)
cos = cos.unsqueeze(1)
if rope_style == RopeStyle.NEOX:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
xr = torch.cat((-x2, x1), dim = -1)
else:
x1 = x[..., 0::2]
x2 = x[..., 1::2]
xr = torch.stack((-x2, x1), dim = -1).flatten(-2)
return (x * cos + xr * sin).transpose(1, 2).half()
def apply_multidim_ref(
q: torch.Tensor,
k: torch.Tensor,
position_ids: torch.Tensor,
q_norm: torch.Tensor | None,
k_norm: torch.Tensor | None,
eps: float,
constant_bias: float,
) -> tuple[torch.Tensor, torch.Tensor]:
if q_norm is not None:
q = apply_norm(q, q_norm, eps, constant_bias)
k = apply_norm(k, k_norm, eps, constant_bias)
out_q = []
out_k = []
for rdim in range(rotate_dims):
start = partial_head_dim * rdim
end = start + partial_head_dim
pos = position_ids[:, :, rdim].float()
freqs = torch.einsum("bi,j->bij", pos, rope_layer.inv_freq.float())
sin = freqs.sin() * rope_layer.attn_factor
cos = freqs.cos() * rope_layer.attn_factor
if rope_style == RopeStyle.NEOX:
sin = torch.cat((sin, sin), dim = -1)
cos = torch.cat((cos, cos), dim = -1)
else:
sin = torch.repeat_interleave(sin, 2, dim = -1)
cos = torch.repeat_interleave(cos, 2, dim = -1)
out_q.append(apply_rope_embed(q[..., start : end], sin, cos))
out_k.append(apply_rope_embed(k[..., start : end], sin, cos))
return torch.cat(out_q, dim = -1), torch.cat(out_k, dim = -1)
base = torch.arange(seq_len, dtype = torch.int, device = device)
position_ids = torch.stack((base % 20, base // 20), dim = -1).unsqueeze(0).repeat(bsz, 1, 1)
position_ids[1, :, 0] += 3
position_ids[1, :, 1] += 5
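# 2D positions: component 0 varies fast (pos % 20), component 1 slow (pos // 20); batch 1 is offset to exercise per-batch ids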
q, k = qk()
eps = 1e-6
constant_bias = 0.0
if use_norm:
torch.manual_seed(1)
norm_q = torch.randn(head_dim, device = device, dtype = torch.half) / 2.0
norm_k = torch.randn(head_dim, device = device, dtype = torch.half) / 2.0
else:
norm_q = None
norm_k = None
q_ref, k_ref = apply_multidim_ref(q, k, position_ids, norm_q, norm_k, eps, constant_bias)
q, k = qk()
q_pre = q.clone()
k_pre = k.clone()
q_out, k_out = rope_layer.apply(q, k, 0, None, position_ids, in_place, norm_q, norm_k, eps, constant_bias)
torch.testing.assert_close(q_out, q_ref, rtol = 3e-3, atol = 3e-3)
torch.testing.assert_close(k_out, k_ref, rtol = 3e-3, atol = 3e-3)
if not in_place:
torch.testing.assert_close(q, q_pre, rtol = 0, atol = 0)
torch.testing.assert_close(k, k_pre, rtol = 0, atol = 0)