diff --git a/exllamav3/exllamav3_ext/rope.cu b/exllamav3/exllamav3_ext/rope.cu
index 5edbed0..a521d36 100644
--- a/exllamav3/exllamav3_ext/rope.cu
+++ b/exllamav3/exllamav3_ext/rope.cu
@@ -12,6 +12,7 @@
 #define ROPESTYLE_NEOX 2
 #define ROPESTYLE_NANOCHAT 3
 #define MAX_NUM_THREADS 1024
+#define MAX_ROTATE_DIMS 4
 
 using bfloat16 = __nv_bfloat16;
 using bfloat162 = __nv_bfloat162;
@@ -44,43 +45,64 @@ void rope_kernel
     const float llama_4_scaling_beta,
     const int llama_4_scaling_original,
     bool post_rope_norm,
-    int position_ids_stride
+    int position_ids_stride,
+    int rotate_dims
 )
 {
-    // Get position
     int batch = blockIdx.y;
     int token_pos = blockIdx.x;
-    int pos = token_pos + position;
-    if (positions)
-        pos = token_pos + positions[batch];
-    else if (position_ids)
-        pos = position_ids[batch * seq_len + token_pos * position_ids_stride];
-
-    // Apply Llama 4 scaling
-    if (llama_4_scaling_beta > 0.0f)
-    {
-        float scaling = 1.0f + llama_4_scaling_beta * __logf(1.0f + (float)(pos / llama_4_scaling_original));
-        attn_factor *= scaling;
-    }
-
-    // Load inv_freq, compute sin/cos
     int t = threadIdx.x;
     int t_head = threadIdx.y;
-    float sin;
-    float cos;
-    if (!inv_freq_table)
+
+    auto get_pos = [&] (int rdim) -> int
     {
-        float fr = inv_freq[t];
-        float pf = __int2float_rn(pos);
-        sin = __sinf(fr * pf) * attn_factor;
-        cos = __cosf(fr * pf) * attn_factor;
-    }
-    else
+        int pos = token_pos + position;
+        if (positions)
+            pos = token_pos + positions[batch];
+        else if (position_ids)
+        {
+            int idx = (batch * seq_len + token_pos) * position_ids_stride;
+            if (position_ids_stride > 1) idx += rdim;
+            pos = position_ids[idx];
+        }
+        return pos;
+    };
+
+    auto get_sincos = [&] (int rdim, float &sin, float &cos)
     {
-        float fr = inv_freq[batch * inv_freq_stride + pos * partial_head_dim / 2 + t];
-        sin = __sinf(fr) * attn_factor;
-        cos = __cosf(fr) * attn_factor;
+        int pos = get_pos(rdim);
+        float local_attn_factor = attn_factor;
+
+        // Apply Llama 4 scaling
+        if (llama_4_scaling_beta > 0.0f)
+        {
+            float scaling = 1.0f + llama_4_scaling_beta * __logf(1.0f + (float)(pos / llama_4_scaling_original));
+            local_attn_factor *= scaling;
+        }
+
+        if (!inv_freq_table)
+        {
+            float fr = inv_freq[t];
+            float pf = __int2float_rn(pos);
+            sin = __sinf(fr * pf) * local_attn_factor;
+            cos = __cosf(fr * pf) * local_attn_factor;
+        }
+        else
+        {
+            float fr = inv_freq[batch * inv_freq_stride + pos * partial_head_dim / 2 + t];
+            sin = __sinf(fr) * local_attn_factor;
+            cos = __cosf(fr) * local_attn_factor;
+        }
+    };
+
+    float sin_cache[MAX_ROTATE_DIMS];
+    float cos_cache[MAX_ROTATE_DIMS];
+    if (t < partial_head_dim / 2)
+    {
+        for (int rdim = 0; rdim < rotate_dims; ++rdim)
+        {
+            get_sincos(rdim, sin_cache[rdim], cos_cache[rdim]);
+        }
     }
 
     // Shared buffer
@@ -129,36 +151,43 @@ void rope_kernel
     // Apply embeddings
     auto apply_rope = [&] ()
     {
-        if (t < partial_head_dim / 2)
+        for (int rdim = 0; rdim < rotate_dims; ++rdim)
         {
-            if constexpr (rope_mode == ROPESTYLE_NEOX)
+            int offset = partial_head_dim * rdim;
+            if (t < partial_head_dim / 2)
             {
-                float v1 = __half2float(sh_head[t]);
-                float v2 = __half2float(sh_head[t + partial_head_dim / 2]);
-                float r1 = v1 * cos - v2 * sin;
-                float r2 = v2 * cos + v1 * sin;
-                sh_head[t] = __float2half_rn(r1);
-                sh_head[t + partial_head_dim / 2] = __float2half_rn(r2);
-            }
-            else if constexpr (rope_mode == ROPESTYLE_NANOCHAT)
-            {
-                float v1 = __half2float(sh_head[t]);
-                float v2 = __half2float(sh_head[t + partial_head_dim / 2]);
-                float r1 = v1 * cos + v2 * sin;
-                float r2 = v2 * cos - v1 * sin;
-                sh_head[t] = __float2half_rn(r1);
-                sh_head[t + partial_head_dim / 2] = __float2half_rn(r2);
-            }
-            else if constexpr (rope_mode == ROPESTYLE_GPTJ)
-            {
-                half2 *tptr = (half2*)(sh_head + t * 2);
-                half2 v = *tptr;
-                float v1 = __low2float(v);
-                float v2 = __high2float(v);
-                float r1 = v1 * cos - v2 * sin;
-                float r2 = v2 * cos + v1 * sin;
-                v = __floats2half2_rn(r1, r2);
-                *tptr = v;
+                float sin = sin_cache[rdim];
+                float cos = cos_cache[rdim];
+
+                if constexpr (rope_mode == ROPESTYLE_NEOX)
+                {
+                    float v1 = __half2float(sh_head[offset + t]);
+                    float v2 = __half2float(sh_head[offset + t + partial_head_dim / 2]);
+                    float r1 = v1 * cos - v2 * sin;
+                    float r2 = v2 * cos + v1 * sin;
+                    sh_head[offset + t] = __float2half_rn(r1);
+                    sh_head[offset + t + partial_head_dim / 2] = __float2half_rn(r2);
+                }
+                else if constexpr (rope_mode == ROPESTYLE_NANOCHAT)
+                {
+                    float v1 = __half2float(sh_head[offset + t]);
+                    float v2 = __half2float(sh_head[offset + t + partial_head_dim / 2]);
+                    float r1 = v1 * cos + v2 * sin;
+                    float r2 = v2 * cos - v1 * sin;
+                    sh_head[offset + t] = __float2half_rn(r1);
+                    sh_head[offset + t + partial_head_dim / 2] = __float2half_rn(r2);
+                }
+                else if constexpr (rope_mode == ROPESTYLE_GPTJ)
+                {
+                    half2 *tptr = (half2*)(sh_head + offset + t * 2);
+                    half2 v = *tptr;
+                    float v1 = __low2float(v);
+                    float v2 = __high2float(v);
+                    float r1 = v1 * cos - v2 * sin;
+                    float r2 = v2 * cos + v1 * sin;
+                    v = __floats2half2_rn(r1, r2);
+                    *tptr = v;
+                }
             }
         }
         __syncthreads();
@@ -305,6 +334,8 @@ void rope
     int head_dim = q.size(3);
     int partial_head_dim = inv_freq.size(-1) * 2;
     int inv_freq_stride = 0;
+    TORCH_CHECK(rotate_dims > 0 && rotate_dims <= MAX_ROTATE_DIMS, "rotate_dims out of range");
+    TORCH_CHECK(rotate_dims == 1 || head_dim == partial_head_dim * rotate_dims, "rotate_dims is inconsistent with inv_freq and head_dim");
 
     const half* q_ptr = (half*) q.data_ptr();
     half* out_q_ptr = (half*) out_q.data_ptr();
@@ -336,6 +367,7 @@ void rope
 
     uint32_t* positions_ptr = (uint32_t*) OPTPTR(positions);
     uint32_t* position_ids_ptr = (uint32_t*) OPTPTR(position_ids);
+    int position_ids_stride = 1;
     TORCH_CHECK_DTYPE_OPT(positions, kInt);
     TORCH_CHECK_DTYPE_OPT(position_ids, kInt);
     TORCH_CHECK((positions_ptr != nullptr) + (position_ids_ptr != nullptr) <= 1, "invalid arguments")
@@ -353,6 +385,7 @@ void rope
         TORCH_CHECK(rd == 2 || (rd == 3 && position_ids.value().size(-1) == rotate_dims), "position_ids wrong number of dims")
         TORCH_CHECK(position_ids.value().size(0) == bsz, "position_ids is incorrect shape");
         TORCH_CHECK(position_ids.value().size(1) == seq_len, "position_ids is incorrect shape");
+        if (rd == 3) position_ids_stride = rotate_dims;
     }
 
     void* q_norm_ptr = (void*) OPTPTR(q_norm);
@@ -375,34 +408,26 @@ void rope
     int parallel_heads = MIN((MAX_NUM_THREADS / thr), num_heads_q + num_heads_k);
     dim3 threads(thr, parallel_heads);
 
-    for (int rdim = rotate_dims - 1; rdim >= 0; --rdim)
+    #define ARGS q_ptr, out_q_ptr, k_ptr, out_k_ptr, inv_freq_ptr, bsz, \
+        seq_len, num_heads_q, num_heads_k, head_dim, partial_head_dim, position, positions_ptr, \
+        position_ids_ptr, attn_factor, q_norm_ptr, k_norm_ptr, norm_eps, norm_constant_bias, inv_freq_table, \
+        inv_freq_stride, llama_4_scaling_beta, llama_4_scaling_original, post_rope_norm, position_ids_stride, rotate_dims
+
+    if (norm_fp16)
     {
-        bool last_dim = rdim == 0;
-        int offset = partial_head_dim * rdim;
-        const uint32_t* position_ids_ptr_ = position_ids_ptr;
-        if (rotate_dims > 1) position_ids_ptr_ += rdim;
-
-        #define ARGS q_ptr + offset, out_q_ptr + offset, k_ptr + offset, out_k_ptr + offset, inv_freq_ptr, bsz, \
-            seq_len, num_heads_q, num_heads_k, head_dim, partial_head_dim, position, positions_ptr, \
-            position_ids_ptr_, attn_factor, last_dim ? q_norm_ptr : nullptr, \
-            last_dim ? k_norm_ptr : nullptr, norm_eps, norm_constant_bias, inv_freq_table, \
-            inv_freq_stride, llama_4_scaling_beta, llama_4_scaling_original, \
-            post_rope_norm && last_dim, rotate_dims
-
-        if (norm_fp16)
-        {
-            if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, half><<<blocks, threads>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, half><<<blocks, threads>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, half><<<blocks, threads>>>(ARGS);
-        }
-        else if (norm_bf16)
-        {
-            if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, bfloat16><<<blocks, threads>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, bfloat16><<<blocks, threads>>>(ARGS);
-            else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, bfloat16><<<blocks, threads>>>(ARGS);
-        }
-        else TORCH_CHECK(false, "rope: incorrect norm dtype");
+        if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, half><<<blocks, threads>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, half><<<blocks, threads>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, half><<<blocks, threads>>>(ARGS);
     }
+    else if (norm_bf16)
+    {
+        if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, bfloat16><<<blocks, threads>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, bfloat16><<<blocks, threads>>>(ARGS);
+        else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, bfloat16><<<blocks, threads>>>(ARGS);
+    }
+    else TORCH_CHECK(false, "rope: incorrect norm dtype");
+
+    #undef ARGS
 
     cuda_check(cudaPeekAtLastError());
 }
diff --git a/tests/test_rope.py b/tests/test_rope.py
index 9417daa..ffc7f50 100644
--- a/tests/test_rope.py
+++ b/tests/test_rope.py
@@ -102,4 +102,124 @@ def test_rope(qk_dim, rope_style, use_norm):
     run(0, torch.randint(size = (bsz,), low = 0, high = 49, dtype = torch.int, device = device), None)
 
     # Batched position ids
-    run(0, None, torch.randint(size = (bsz, seq_len), low = 0, high = 117, dtype = torch.int, device = device))
\ No newline at end of file
+    run(0, None, torch.randint(size = (bsz, seq_len), low = 0, high = 117, dtype = torch.int, device = device))
+
+
+@pytest.mark.parametrize("rope_style", rope_styles)
+@pytest.mark.parametrize("use_norm", norm_opt)
+@pytest.mark.parametrize("in_place", [False, True])
+@torch.inference_mode()
+def test_rope_multidim(rope_style, use_norm, in_place):
+
+    bsz = 2
+    seq_len = 280
+    num_heads = 16
+    head_dim = 72
+    rotate_dims = 2
+    partial_head_dim = head_dim // rotate_dims
+
+    rope_layer = RoPE(
+        device = device,
+        rope_settings = RopeSettings(
+            rope_theta = 100.0,
+            head_dim = partial_head_dim,
+            rope_scaling = None,
+            max_position_embeddings = 131072,
+            partial_rotary_factor = 1.0,
+            rope_style = rope_style,
+            rotate_dims = rotate_dims,
+        )
+    )
+
+    def qk():
+        torch.manual_seed(0)
+        q_pr = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)
+        k_pr = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)
+        return q_pr, k_pr
+
+    def apply_norm(
+        x: torch.Tensor,
+        w: torch.Tensor | None,
+        eps: float,
+        constant_bias: float
+    ) -> torch.Tensor:
+        x = x.float()
+        var = x.pow(2).mean(dim = -1, keepdim = True) + eps
+        x = x * torch.rsqrt(var)
+        if w is not None:
+            x = x * (w.float() + constant_bias)
+        return x.half()
+
+    def apply_rope_embed(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
+        x = x.transpose(1, 2)
+        sin = sin.unsqueeze(1)
+        cos = cos.unsqueeze(1)
+        if rope_style == RopeStyle.NEOX:
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2:]
+            xr = torch.cat((-x2, x1), dim = -1)
+        else:
+            x1 = x[..., 0::2]
+            x2 = x[..., 1::2]
+            xr = torch.stack((-x2, x1), dim = -1).flatten(-2)
+        return (x * cos + xr * sin).transpose(1, 2).half()
+
+    def apply_multidim_ref(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        position_ids: torch.Tensor,
+        q_norm: torch.Tensor | None,
+        k_norm: torch.Tensor | None,
+        eps: float,
+        constant_bias: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if q_norm is not None:
+            q = apply_norm(q, q_norm, eps, constant_bias)
+            k = apply_norm(k, k_norm, eps, constant_bias)
+
+        out_q = []
+        out_k = []
+        for rdim in range(rotate_dims):
+            start = partial_head_dim * rdim
+            end = start + partial_head_dim
+            pos = position_ids[:, :, rdim].float()
+            freqs = torch.einsum("bi,j->bij", pos, rope_layer.inv_freq.float())
+            sin = freqs.sin() * rope_layer.attn_factor
+            cos = freqs.cos() * rope_layer.attn_factor
+            if rope_style == RopeStyle.NEOX:
+                sin = torch.cat((sin, sin), dim = -1)
+                cos = torch.cat((cos, cos), dim = -1)
+            else:
+                sin = torch.repeat_interleave(sin, 2, dim = -1)
+                cos = torch.repeat_interleave(cos, 2, dim = -1)
+            out_q.append(apply_rope_embed(q[..., start : end], sin, cos))
+            out_k.append(apply_rope_embed(k[..., start : end], sin, cos))
+
+        return torch.cat(out_q, dim = -1), torch.cat(out_k, dim = -1)
+
+    base = torch.arange(seq_len, dtype = torch.int, device = device)
+    position_ids = torch.stack((base % 20, base // 20), dim = -1).unsqueeze(0).repeat(bsz, 1, 1)
+    position_ids[1, :, 0] += 3
+    position_ids[1, :, 1] += 5
+
+    q, k = qk()
+    eps = 1e-6
+    constant_bias = 0.0
+    if use_norm:
+        torch.manual_seed(1)
+        norm_q = torch.randn(head_dim, device = device, dtype = torch.half) / 2.0
+        norm_k = torch.randn(head_dim, device = device, dtype = torch.half) / 2.0
+    else:
+        norm_q = None
+        norm_k = None
+
+    q_ref, k_ref = apply_multidim_ref(q, k, position_ids, norm_q, norm_k, eps, constant_bias)
+    q, k = qk()
+    q_pre = q.clone()
+    k_pre = k.clone()
+    q_out, k_out = rope_layer.apply(q, k, 0, None, position_ids, in_place, norm_q, norm_k, eps, constant_bias)
+    torch.testing.assert_close(q_out, q_ref, rtol = 3e-3, atol = 3e-3)
+    torch.testing.assert_close(k_out, k_ref, rtol = 3e-3, atol = 3e-3)
+    if not in_place:
+        torch.testing.assert_close(q, q_pre, rtol = 0, atol = 0)
+        torch.testing.assert_close(k, k_pre, rtol = 0, atol = 0)
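Usage sketch (not a definitive API, mirrors test_rope_multidim above): the new path expects position_ids of shape (bsz, seq_len, rotate_dims), one position per rotation axis, and each axis rotates its own head_dim // rotate_dims slice of the head. The snippet below assumes RoPE, RopeSettings, RopeStyle and device are imported/defined exactly as in tests/test_rope.py; the grid layout, theta and tensor sizes are arbitrary illustrative values, and the apply() argument order is copied from the test.

# Hypothetical example, reusing names from tests/test_rope.py.
import torch

bsz, seq_len, num_heads, head_dim, rotate_dims = 1, 64, 8, 128, 2

rope_layer = RoPE(
    device = device,
    rope_settings = RopeSettings(
        rope_theta = 10000.0,
        head_dim = head_dim // rotate_dims,   # per-axis rotary width
        rope_scaling = None,
        max_position_embeddings = 131072,
        partial_rotary_factor = 1.0,
        rope_style = RopeStyle.NEOX,
        rotate_dims = rotate_dims,
    )
)

# One position per rotation axis, e.g. an (x, y) coordinate per token for an
# 8-wide 2D layout, giving shape (bsz, seq_len, rotate_dims).
base = torch.arange(seq_len, dtype = torch.int, device = device)
position_ids = torch.stack((base % 8, base // 8), dim = -1).unsqueeze(0).repeat(bsz, 1, 1)

q = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)
k = torch.randn((bsz, seq_len, num_heads, head_dim), dtype = torch.half, device = device)

# Same call shape as in the test: position = 0, positions = None,
# in_place = False, no QK-norm tensors.
q_rot, k_rot = rope_layer.apply(q, k, 0, None, position_ids, False, None, None, 1e-6, 0.0)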