mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
319 lines
12 KiB
Python
319 lines
12 KiB
Python
|
|
import math
|
|
import os, sys
|
|
import time
|
|
import subprocess
|
|
import platform
|
|
import json
|
|
from typing import Any, Dict, Optional, Tuple
|
|
import torch
|
|
import torch.nn.init as init
|
|
from torch import nn
|
|
|
|
class KDeepSeekV3Cache(nn.Module):
    """Paged latent KV cache for DeepSeek-V3-style attention.

    Each layer stores keys and values in a single shared tensor
    (``self.k_caches[layer_idx]``): key latents occupy the first
    ``kv_lora_rank`` channels of the last dimension and value latents
    the remainder.
    """

    def __init__(
        self,
        # config: PretrainedConfig,
        page_size: int = 256,        # token slots per cache page
        kv_lora_rank: int = 128,     # channel split point between key and value latents
        k_caches: Optional[torch.Tensor] = None,  # pre-allocated per-layer cache tensors
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        # self.config = config
        self.dtype = dtype
        self.device = device
        self.kv_lora_rank = kv_lora_rank
        self.page_size = page_size
        # v_caches is never written in this class; presumably populated
        # elsewhere (or vestigial) -- TODO confirm against callers.
        self.v_caches = []
        self.k_caches = k_caches

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        page_idx: torch.Tensor,
        page_offset: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache. Flattened over its first two
                dimensions before being scattered into the page slots.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            page_idx (`torch.Tensor`):
                Physical page index for each token being written.
            page_offset (`torch.Tensor`):
                Slot offset within the page for each token.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass; unused here.

        Return:
            The updated cache tensor for `layer_idx` (keys in the first
            `kv_lora_rank` channels of the last dim, values in the rest).
        """
        k_out = self.k_caches[layer_idx]

        # Keys and values share one buffer, split along the channel dimension.
        k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])
        k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])
        return k_out

    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):
        """Translate absolute token positions into (page index, in-page offset).

        `q_indptr` / `kv_indptr` are CSR-style boundary arrays per request;
        `kv_indices` maps a request's local page numbers to physical pages.
        Returns a `(page_idx, page_offset)` pair aligned with `cache_position`.
        Tokens whose local page falls outside the request's kv range keep
        page_idx 0 -- presumably unreachable for valid inputs; verify.
        """
        page_offset = cache_position % self.page_size
        page_idx_local = cache_position // self.page_size
        # Map each token back to the request (query id) that owns it.
        query_ids = torch.zeros_like(cache_position)
        for i in range(len(q_indptr) - 1):
            start_idx = q_indptr[i]
            end_idx = q_indptr[i + 1]
            query_ids[start_idx:end_idx] = i
        # Resolve each local page number to a physical page via kv_indices.
        page_idx = torch.zeros_like(page_idx_local)
        for i in range(bsz_tensors[0]):
            query_id = query_ids[i]
            local_block = page_idx_local[i]
            start_block = kv_indptr[query_id]
            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
                page_idx[i] = kv_indices[start_block + local_block]

        return page_idx, page_offset
|
|
|
|
|
|
|
|
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotate the two halves of the last dimension: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they
            broadcast against `q` and `k`: use 1 for
            [batch, heads, seq, head_dim] tensors, 2 for
            [batch, seq, heads, head_dim] tensors.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated
        using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # Reorder (x0, y0, x1, y1, ...) pairs into (x0, x1, ..., y0, y1, ...)
        # along the head dim so the rotate-half formulation applies.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _deinterleave(q)
    k = _deinterleave(k)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
|
|
|
|
|
|
class DeepseekV2RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps=1e-6):
        """Create a learnable per-channel scale of ones and store epsilon."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize in float32 for stability, then cast back to the input dtype."""
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return (self.weight * normalized).to(orig_dtype)
|
|
|
|
|
|
class DeepseekV2RotaryEmbedding(nn.Module):
    """Rotary embedding that computes cos/sin on the fly from position ids."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return (cos, sin) for `position_ids`, cast to x's dtype.

        x: [bs, num_attention_heads, seq_len, head_size] -- used only for
        dtype and device.
        """
        freq_col = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        pos_row = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        dev_type = x.device.type
        if not (isinstance(dev_type, str) and dev_type != "mps"):
            dev_type = "cpu"
        with torch.autocast(device_type=dev_type, enabled=False):
            angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            doubled = torch.cat((angles, angles), dim=-1)
            cos, sin = doubled.cos(), doubled.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
class DeepseekV3RotaryEmbedding(nn.Module):
    """Rotary position embedding with precomputed cos/sin tables.

    The tables are built once in ``__init__`` (so `torch.jit.trace` works)
    and indexed by position in ``forward``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Precompute cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        # (removed leftover debug print of emb.shape that fired on every build)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin) rows for the positions in `seq_len`.

        x: [bs, num_attention_heads, seq_len, head_size] -- used only for
        dtype; `seq_len` (an int index or index tensor) selects rows of the
        cached tables directly.
        """
        # NOTE(review): the cache is always built in __init__, so this guard
        # never fires; the commented-out condition suggests it was meant to
        # grow the cache past max_seq_len_cached -- confirm intent upstream.
        if self.max_seq_len_cached is None:  # or seq_len[-1] > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[seq_len].to(dtype=x.dtype),
            self.sin_cached[seq_len].to(dtype=x.dtype),
        )
|
|
|
|
# Inverse dim formula to find dim based on number of rotations
|
|
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the (fractional) rotary dim whose wavelength completes
    `num_rotations` rotations over `max_position_embeddings` positions."""
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    denominator = 2 * math.log(base)
    return numerator / denominator
|
|
|
|
|
|
# Find dim range bounds based on rotations
|
|
# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return the integer dim interval [low, high] between the `low_rot`
    and `high_rot` rotation counts, clamped to [0, dim - 1]."""
    lo = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    hi = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp values just in case
    return max(lo, 0), min(hi, dim - 1)
|
|
|
|
def yarn_linear_ramp_mask(min, max, dim):
    """Return a length-`dim` float tensor ramping 0 -> 1 between `min` and
    `max` (flat outside that interval)."""
    if min == max:
        max += 0.001  # Prevent singularity

    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return torch.clamp(ramp, 0, 1)
|
|
|
|
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention-magnitude factor: 1.0 for scale <= 1, otherwise a
    logarithmic correction weighted by `mscale`."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
|
|
|
|
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """YaRN-scaled rotary embedding for context-length extension.

    Blends unscaled and NTK-scaled inverse frequencies with a linear ramp
    between the `beta_fast`/`beta_slow` correction dims, and multiplies the
    cos/sin tables by the YaRN magnitude factor.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # Attributes must be assigned BEFORE super().__init__(): the parent
        # constructor calls _set_cos_sin_cache(), which reads them.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build the YaRN-blended cos/sin tables (overrides the parent)."""
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Inverse frequencies without (extra) and with (inter) NTK scaling.
        freq_extra = 1.0 / (
            self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # Ramp from fully-scaled to unscaled frequencies across [low, high].
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Check whether seq_len is a tensor of explicit positions or a length.
        if isinstance(seq_len,torch.Tensor):
            t = seq_len
        else:
            t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        # YaRN attention-magnitude correction applied to both tables.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )
|
|
|