# Source: ktransformers/kt-kernel/examples/torch_attention.py
# Snapshot: 2025-10-12 05:13:00 +00:00
#
# 319 lines, 12 KiB, Python

import math
import os, sys
import time
import subprocess
import platform
import json
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn.init as init
from torch import nn
class KDeepSeekV3Cache(nn.Module):
    """Paged per-layer cache that stores key and value states side by side.

    For every layer, ``k_caches[layer_idx]`` is indexed as
    ``[page, offset_in_page, :, feature]``. The first ``kv_lora_rank`` entries
    of the feature axis receive the key states and the remaining entries
    receive the value states (see ``update``).
    """

    def __init__(
        self,
        page_size: int = 256,
        kv_lora_rank: int = 128,
        k_caches: Optional[torch.Tensor] = None,
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    ):
        """Store the cache geometry and the (externally allocated) cache storage.

        Parameters:
            page_size (`int`):
                Number of token slots per cache page.
            kv_lora_rank (`int`):
                Width of the key part of the feature axis; everything after it
                holds the value part.
            k_caches (`torch.Tensor`, `optional`):
                Cache storage, indexed by layer first. It must be provided
                before `update` is called.
            dtype:
                Recorded on the instance; not used by the methods of this class.
            device:
                Recorded on the instance; not used by the methods of this class.
        """
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.kv_lora_rank = kv_lora_rank
        self.page_size = page_size
        # Kept for interface compatibility; nothing in this class fills it.
        self.v_caches = []
        self.k_caches = k_caches

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        page_idx: torch.Tensor,
        page_offset: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache. The first two dimensions are flattened
                into one token axis before writing.
            value_states (`torch.Tensor`):
                The new value states to cache. Flattened the same way as `key_states`.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            page_idx (`torch.Tensor`):
                Page index of every token being written.
            page_offset (`torch.Tensor`):
                Slot inside the page of every token being written.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Accepted for interface compatibility; unused here.

        Return:
            The cache tensor of layer `layer_idx` (a single tensor, updated in
            place), holding the key states in `[..., :kv_lora_rank]` and the
            value states in `[..., kv_lora_rank:]`.
        """
        k_out = self.k_caches[layer_idx]
        flat_keys = key_states.reshape(-1, *key_states.shape[2:])
        flat_values = value_states.reshape(-1, *value_states.shape[2:])
        k_out[page_idx, page_offset, :, : self.kv_lora_rank] = flat_keys
        k_out[page_idx, page_offset, :, self.kv_lora_rank :] = flat_values
        return k_out

    def get_page_table(
        self,
        cache_position: torch.Tensor,
        q_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        kv_indices: torch.Tensor,
        bsz_tensors: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Translate per-token cache positions into (page index, offset in page).

        Parameters:
            cache_position (`torch.Tensor`):
                Position of every token inside its own request.
            q_indptr (`torch.Tensor`):
                Boundaries such that tokens `q_indptr[i]:q_indptr[i + 1]` belong
                to request `i`.
            kv_indptr (`torch.Tensor`):
                Boundaries such that `kv_indices[kv_indptr[i]:kv_indptr[i + 1]]`
                are the pages of request `i`.
            kv_indices (`torch.Tensor`):
                Flat list of page indices for all requests.
            bsz_tensors (`torch.Tensor`):
                Its first element is the number of leading tokens to resolve.

        Return:
            `(page_idx, page_offset)`, both shaped like `cache_position`. Tokens
            that are not resolved, or whose local page lies outside the pages
            listed for their request, keep page index 0.
        """
        page_offset = cache_position % self.page_size
        page_idx_local = cache_position // self.page_size

        # Map every token to the request that owns it.
        query_ids = torch.zeros_like(cache_position)
        for i in range(len(q_indptr) - 1):
            start_idx = q_indptr[i]
            end_idx = q_indptr[i + 1]
            query_ids[start_idx:end_idx] = i

        page_idx = torch.zeros_like(page_idx_local)
        for i in range(bsz_tensors[0]):
            query_id = query_ids[i]
            local_block = page_idx_local[i]
            start_block = kv_indptr[query_id]
            # Only look up pages that the request actually owns.
            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
                page_idx[i] = kv_indices[start_block + local_block]
        return page_idx, page_offset
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half moved to the front."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with a Rotary Position Embedding.

    Before the rotation, the last dimension of `q` and `k` is regrouped from
    interleaved pairs `[a0, b0, a1, b1, ...]` into two contiguous halves
    `[a0, a1, ..., b0, b1, ...]`.

    Args:
        q (`torch.Tensor`): The query tensor (4-dimensional).
        k (`torch.Tensor`): The key tensor (4-dimensional).
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they
            broadcast against `q` and `k`. For example, if `cos` and `sin` have the
            shape [batch_size, seq_len, head_dim] and `q` and `k` have the shape
            [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1. If `q` and
            `k` have the shape [batch_size, seq_len, heads, head_dim], use
            unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """

    def _split_pairs(states):
        # Move the even-indexed features to the first half, the odd-indexed to the second.
        b, h, s, d = states.shape
        return states.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q = _split_pairs(q)
    k = _split_pairs(k)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
class DeepseekV2RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise over the last dimension in float32 and return the input's dtype."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = upcast * inverse_rms
        return (self.weight * normalised).to(original_dtype)
class DeepseekV2RotaryEmbedding(nn.Module):
    """Rotary embedding that computes cos/sin on the fly from explicit position ids."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device type and dtype.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        batch = position_ids.shape[0]
        inv_freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        positions_row = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (inv_freq_column.float() @ positions_row.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos()
            sin = table.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class DeepseekV3RotaryEmbedding(nn.Module):
    """Rotary embedding backed by cos/sin tables that are precomputed at construction."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Compute the inverse frequencies and build the cos/sin tables.

        Args:
            dim: Size of the rotary feature dimension; the tables have `dim` columns.
            max_position_embeddings: Number of positions (rows) to precompute.
            base: Base of the geometric frequency progression.
            device: Device on which the inverse frequencies are created.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build `cos_cached` / `sin_cached` for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        # Registering a name that is already a buffer simply replaces the old table.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Look up cos/sin rows and cast them to the dtype of `x`.

        Args:
            x: Only its dtype (and, when the cache must be built, its device) is used.
            seq_len: Used directly as the index into the cached tables, e.g. a
                tensor of positions. The tables are not grown here, so every
                index must be below `max_seq_len_cached`.

        Returns:
            `(cos, sin)` gathered from the cached tables.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.max_seq_len_cached is None:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[seq_len].to(dtype=x.dtype),
            self.sin_cached[seq_len].to(dtype=x.dtype),
        )
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the (fractional) dimension index matching `num_rotations` rotations."""
    rotations_term = math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    numerator = dim * rotations_term
    denominator = 2 * math.log(base)
    return numerator / denominator
# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer `(low, high)` dimension bounds, clamped to `[0, dim - 1]`."""
    lower = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    upper = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp values just in case
    lower = max(lower, 0)
    upper = min(upper, dim - 1)
    return lower, upper
def yarn_linear_ramp_mask(min, max, dim):
    """Return a length-`dim` float32 ramp: 0 up to `min`, rising linearly to 1 at `max`."""
    if min == max:
        # Prevent singularity
        max += 0.001

    positions = torch.arange(dim, dtype=torch.float32)
    unclamped = (positions - min) / (max - min)
    return torch.clamp(unclamped, 0, 1)
def yarn_get_mscale(scale=1, mscale=1):
    """Return the magnitude scale: 1.0 for `scale <= 1`, else `0.1 * mscale * ln(scale) + 1.0`."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary embedding whose cos/sin tables blend scaled and unscaled frequencies (YaRN)."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # These settings must exist before the parent constructor runs, because
        # that constructor calls `_set_cos_sin_cache`, which reads them.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild `inv_freq`, `cos_cached` and `sin_cached`.

        `seq_len` is either a length (the tables then cover `0 .. seq_len - 1`)
        or a tensor, which is used as the positions themselves.
        """
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Both frequency sets share the same `base ** (2i / dim)` term.
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        base_powers = self.base ** exponents
        freq_extra = 1.0 / base_powers
        freq_inter = 1.0 / (self.scaling_factor * base_powers)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        ramp = yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq_mask = 1.0 - ramp
        # Mask 1 keeps the unscaled frequency, mask 0 takes the scaled one.
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Check whether seq_len is a tensor
        t = (
            seq_len
            if isinstance(seq_len, torch.Tensor)
            else torch.arange(seq_len, device=device, dtype=torch.float32)
        )
        freqs = torch.outer(t, inv_freq)

        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        cos_table = (emb.cos() * _mscale).to(dtype)
        sin_table = (emb.sin() * _mscale).to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)